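"""Stacking ensemble: three pretrained CNN base models produce class logits that are
concatenated and fed to an MLP meta-learner, which is trained with CutMix augmentation
and a cosine-annealed learning rate, then evaluated on the held-out test set."""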
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from data_loader import load_data, load_test_data
from configs import *

torch.cuda.empty_cache()
class MLP(nn.Module):
    """Meta-learner: maps the concatenated logits of the base models to class scores."""

    def __init__(self, num_classes, num_models):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(num_classes * num_models, 1024),
            nn.LayerNorm(1024),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(0.8),
            nn.Linear(1024, 2048),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(2048, 2048),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(2048, num_classes),
        )

    def forward(self, x):
        # Flatten to (batch, num_classes * num_models) before the MLP.
        x = x.view(x.size(0), -1)
        return self.layers(x)


def mlp_meta(num_classes, num_models):
    return MLP(num_classes, num_models)
# Hyperparameters
input_dim = 3 * 224 * 224  # Modify this based on your input size
hidden_dim = 256
output_dim = NUM_CLASSES

# Create the data loaders using the data_loader helpers
train_loader, val_loader = load_data(COMBINED_DATA_DIR + "1", preprocess, BATCH_SIZE)
test_loader = load_test_data("data/test/Task 1", preprocess, BATCH_SIZE)
model_paths = [
    "output/checkpoints/bestsqueezenetSE3.pth",
    "output/checkpoints/EfficientNetB3WithDropout.pth",
    "output/checkpoints/MobileNetV2WithDropout2.pth",
]


# Load a pretrained model's weights from a checkpoint and move it to the target device.
def load_pretrained_model(path, model):
    model.load_state_dict(torch.load(path, map_location=DEVICE))
    return model.to(DEVICE)
def rand_bbox(size, lam):
    """Sample a random box covering roughly (1 - lam) of the image area.

    Follows the reference CutMix implementation: `size` is an NCHW shape and the
    box is sampled over the two spatial dimensions.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1.0 - lam)
    cut_w = np.int_(W * cut_rat)
    cut_h = np.int_(H * cut_rat)

    # Uniformly sample the box centre, then clip the box to the image bounds.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
def cutmix_data(input, target, alpha=1.0):
    """Apply CutMix: paste a random patch from shuffled samples and mix the targets."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    rand_index = torch.randperm(input.size(0))
    bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
    input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]

    # Adjust lam to the exact area ratio of the pasted patch.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
    targets_a = target
    targets_b = target[rand_index]
    return input, targets_a, targets_b, lam


def cutmix_criterion(criterion, outputs, targets_a, targets_b, lam):
    # The loss is interpolated between the two target sets by the mixed-area ratio.
    return lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
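# Illustrative usage (mirrors the training loop below): for a batch x, y,
#     x_mix, y_a, y_b, lam = cutmix_data(x, y, alpha=1.0)
#     loss = cutmix_criterion(criterion, model(x_mix), y_a, y_b, lam)
# patches are swapped between samples and the loss is weighted by the mixed-area ratio.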
# Load the pretrained base models (load_pretrained_model already moves them to DEVICE)
model1 = load_pretrained_model(
    model_paths[0], SqueezeNet1_0WithSE(num_classes=NUM_CLASSES)
)
model2 = load_pretrained_model(
    model_paths[1], EfficientNetB3WithDropout(num_classes=NUM_CLASSES)
)
model3 = load_pretrained_model(
    model_paths[2], MobileNetV2WithDropout(num_classes=NUM_CLASSES)
)
models = [model1, model2, model3]
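
# Assumption: the base models serve as fixed feature extractors for the stacking
# ensemble (only the meta-learner is optimized below), so switch them to eval mode
# and freeze their weights to keep dropout/batch-norm deterministic and avoid
# building gradients that are never used.
for base_model in models:
    base_model.eval()
    for param in base_model.parameters():
        param.requires_grad_(False)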
# Create the meta learner
meta_learner_model = mlp_meta(NUM_CLASSES, len(models)).to(DEVICE)

# Loss function and optimizer for the meta learner
criterion = nn.CrossEntropyLoss().to(DEVICE)
meta_optimizer = torch.optim.Adam(meta_learner_model.parameters(), lr=0.001)

# Cosine annealing schedule. Note: scheduler.step() is called once per batch below,
# so T_max counts optimizer steps (batches), not epochs.
scheduler = CosineAnnealingLR(meta_optimizer, T_max=700)

# Record learning rate
lr_hist = []

# Training loop
num_epochs = 160
for epoch in range(num_epochs):
    print("[Epoch: {}]".format(epoch + 1))
    print("Total number of batches: {}".format(len(train_loader)))

    meta_learner_model.train()  # reset after eval() at the end of the previous epoch
    for batch_idx, data in enumerate(train_loader, 0):
        print("Batch: {}".format(batch_idx + 1))
        inputs, labels = data
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        inputs, targets_a, targets_b, lam = cutmix_data(inputs, labels, alpha=1)

        # Forward pass through the three pretrained models
        features1 = model1(inputs)
        features2 = model2(inputs)
        features3 = model3(inputs)

        # Concatenate the class logits from the three models
        stacked_features = torch.cat((features1, features2, features3), dim=1).to(DEVICE)

        # Forward pass through the meta learner
        meta_output = meta_learner_model(stacked_features)

        # Compute the CutMix-weighted loss
        loss = cutmix_criterion(criterion, meta_output, targets_a, targets_b, lam)

        # Compute the accuracy (against the original labels; approximate under CutMix)
        _, predicted = torch.max(meta_output, 1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()

        # Backpropagation and optimization
        meta_optimizer.zero_grad()
        loss.backward()
        meta_optimizer.step()

        lr_hist.append(meta_optimizer.param_groups[0]["lr"])
        scheduler.step()

        print("Train Loss: {}".format(loss.item()))
        print("Train Accuracy: {}%".format(100 * correct / total))
    # Validation
    meta_learner_model.eval()
    correct = 0
    total = 0
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            inputs, labels = data
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            features1 = model1(inputs)
            features2 = model2(inputs)
            features3 = model3(inputs)
            stacked_features = torch.cat((features1, features2, features3), dim=1).to(DEVICE)

            outputs = meta_learner_model(stacked_features)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Report the average validation loss over batches
    print("Validation Loss: {}".format(val_loss / len(val_loader)))
    print("Validation Accuracy: {}%".format(100 * correct / total))
print("Finished Training") | |
# Test the ensemble | |
print("Testing the ensemble") | |
meta_learner_model.eval() | |
correct = 0 | |
total = 0 | |
with torch.no_grad(): | |
for data in test_loader: | |
inputs, labels = data | |
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) | |
features1 = model1(inputs) | |
features2 = model2(inputs) | |
features3 = model3(inputs) | |
stacked_features = torch.cat((features1, features2, features3), dim=1) | |
outputs = meta_learner_model(stacked_features) | |
_, predicted = torch.max(outputs, 1) | |
total += labels.size(0) | |
correct += (predicted == labels).sum().item() | |
print( | |
"Accuracy of the ensemble network on the test images: {}%".format( | |
100 * correct / total | |
) | |
) | |
# Plot the learning rate history
plt.plot(lr_hist)
plt.xlabel("Iterations")
plt.ylabel("Learning Rate")
plt.title("Learning Rate History")
plt.show()

# Save the meta-learner weights
torch.save(meta_learner_model.state_dict(), "output/checkpoints/ensemble.pth")