Spaces:
Runtime error
Runtime error
import os | |
import torch | |
from torchvision import transforms | |
from torch.utils.data import Dataset | |
from models import * | |
import torch.nn as nn | |
from torchvision.models import ( | |
squeezenet1_0, | |
SqueezeNet1_0_Weights, | |
mobilenet_v3_small, | |
MobileNet_V3_Small_Weights, | |
) | |
from torchvision.models import squeezenet1_0 | |
# Constants | |
RANDOM_SEED = 123 | |
BATCH_SIZE = 16 | |
NUM_EPOCHS = 40 | |
LEARNING_RATE = 5.488903014780378e-05 | |
STEP_SIZE = 10 | |
GAMMA = 0.3 | |
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
NUM_PRINT = 100 | |
TASK = 1 | |
RAW_DATA_DIR = r"data/train/raw/Task " | |
AUG_DATA_DIR = r"data/train/augmented/Task " | |
EXTERNAL_DATA_DIR = r"data/train/external/Task " | |
TEMP_DATA_DIR = "data/temp/" | |
NUM_CLASSES = 7 | |
EARLY_STOPPING_PATIENCE = 20 | |
CLASSES = [ | |
"Alzheimer Disease", | |
"Cerebral Palsy", | |
"Dystonia", | |
"Essential Tremor", | |
"Healthy", | |
"Huntington Disease", | |
"Parkinson Disease", | |
] | |
MODEL_SAVE_PATH = r"output/checkpoints/model.pth" | |
class SqueezeNet1_0WithDropout(nn.Module): | |
def __init__(self, num_classes=1000): | |
super(SqueezeNet1_0WithDropout, self).__init__() | |
squeezenet = squeezenet1_0(weights=SqueezeNet1_0_Weights.DEFAULT) | |
self.features = squeezenet.features | |
self.classifier = nn.Sequential( | |
nn.Conv2d(512, num_classes, kernel_size=1), | |
nn.BatchNorm2d(num_classes), # add batch normalization | |
nn.ReLU(inplace=True), | |
nn.AdaptiveAvgPool2d((1, 1)), | |
) | |
def forward(self, x): | |
x = self.features(x) | |
x = self.classifier(x) | |
x = torch.flatten(x, 1) | |
return x | |
# class ShuffleNetV2WithDropout(nn.Module): | |
# def __init__(self, num_classes=1000): | |
# super(ShuffleNetV2WithDropout, self).__init__() | |
# shufflenet = shufflenet_v2_x2_0(weights=ShuffleNet_V2_X2_0_Weights) | |
# self.features = shufflenet.features | |
# self.classifier = nn.Sequential( | |
# nn.Conv2d(1024, num_classes, kernel_size=1), | |
# nn.BatchNorm2d(num_classes), # add batch normalization | |
# nn.ReLU(inplace=True), | |
# nn.AdaptiveAvgPool2d((1, 1)) | |
# ) | |
# def forward(self, x): | |
# x = self.features(x) | |
# x = self.classifier(x) | |
# x = torch.flatten(x, 1) | |
# return x | |
class MobileNetV3SmallWithDropout(nn.Module): | |
def __init__(self, num_classes=1000): | |
super(MobileNetV3SmallWithDropout, self).__init__() | |
mobilenet = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights) | |
self.features = mobilenet.features | |
self.classifier = nn.Sequential( | |
nn.Conv2d(576, num_classes, kernel_size=1), | |
nn.BatchNorm2d(num_classes), # add batch normalization | |
nn.ReLU(inplace=True), | |
nn.AdaptiveAvgPool2d((1, 1)), | |
) | |
def forward(self, x): | |
x = self.features(x) | |
x = self.classifier(x) | |
x = torch.flatten(x, 1) | |
return x | |
class ResNet18WithNorm(nn.Module): | |
def __init__(self, num_classes=1000): | |
super(ResNet18WithNorm, self).__init__() | |
resnet = resnet18(pretrained=False) | |
self.features = nn.Sequential( | |
*list(resnet.children())[:-2] | |
) # Remove last 2 layers (avgpool and fc) | |
self.classifier = nn.Sequential( | |
nn.AdaptiveAvgPool2d((1, 1)), | |
nn.Flatten(), | |
nn.Linear(512, num_classes), | |
nn.BatchNorm2d(num_classes), # Add batch normalization | |
) | |
def forward(self, x): | |
x = self.features(x) | |
x = self.classifier(x) | |
x = torch.flatten(x, 1) | |
return x | |
MODEL = SqueezeNet1_0WithDropout(num_classes=7) | |
print(CLASSES) | |
preprocess = transforms.Compose( | |
[ | |
transforms.Resize((64, 64)), # Resize images to 64x64 | |
transforms.ToTensor(), # Convert to tensor | |
# Normalize 3 channels | |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | |
] | |
) | |
# Custom dataset class | |
class CustomDataset(Dataset): | |
def __init__(self, dataset): | |
self.data = dataset | |
def __len__(self): | |
return len(self.data) | |
def __getitem__(self, idx): | |
img, label = self.data[idx] | |
return img, label | |