# Spaces:
# Runtime error
# Runtime error
import os | |
import torch | |
from torchvision import transforms | |
from torch.utils.data import Dataset | |
from models import * | |
import torch.nn as nn | |
from torchvision.models import squeezenet1_0, SqueezeNet1_0_Weights | |
from torchvision.models import squeezenet1_0 | |
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Reproducibility / runtime placement.
RANDOM_SEED = 123
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Optimisation hyperparameters.
BATCH_SIZE = 16
NUM_EPOCHS = 40
# High-precision value; presumably the output of a hyperparameter search.
LEARNING_RATE = 5.488903014780378e-05
# NOTE(review): STEP_SIZE / GAMMA presumably configure a step-decay LR
# scheduler (LR multiplied by GAMMA every STEP_SIZE epochs) -- confirm in the
# training code.
STEP_SIZE = 10
GAMMA = 0.3
EARLY_STOPPING_PATIENCE = 20
NUM_PRINT = 100  # presumably a logging interval in batches -- confirm

# Data locations. The three training prefixes end in "Task " -- presumably
# the caller appends the task number (TASK); confirm before changing them.
TASK = 1
RAW_DATA_DIR = r"data/train/raw/Task "
AUG_DATA_DIR = r"data/train/augmented/Task "
EXTERNAL_DATA_DIR = r"data/train/external/Task "
TEMP_DATA_DIR = "data/temp/"

# Label space. NUM_CLASSES must equal len(CLASSES). The list is in
# alphabetical order; presumably its order defines each class's label index.
NUM_CLASSES = 7
CLASSES = [
    'Alzheimer Disease',
    'Cerebral Palsy',
    'Dystonia',
    'Essential Tremor',
    'Healthy',
    'Huntington Disease',
    'Parkinson Disease',
]

# Output artifacts.
MODEL_SAVE_PATH = "output/checkpoints/model.pth"
class SqueezeNet1_0WithDropout(nn.Module):
    """SqueezeNet 1.0 backbone with a custom 1x1-convolution classification head.

    The convolutional ``features`` come from torchvision's ``squeezenet1_0``.
    Its stock classifier is discarded and replaced by::

        Conv2d(512 -> num_classes, 1x1) -> BatchNorm2d -> ReLU -> AdaptiveAvgPool2d(1x1)

    Despite the class name, the head contains no dropout layer. The name and
    layer order are intentionally left unchanged: inserting a layer into
    ``classifier`` would shift the ``nn.Sequential`` indices, and therefore the
    ``state_dict`` keys, breaking previously saved checkpoints.

    Args:
        num_classes: Number of output classes produced by the head.
        weights: Backbone weights, forwarded unchanged to
            ``torchvision.models.squeezenet1_0``. Accepts a
            ``SqueezeNet1_0_Weights`` member, the name of one as a string
            (the default ``"DEFAULT"`` selects torchvision's recommended
            pretrained weights), or ``None`` for random initialisation, e.g.
            when a fine-tuned checkpoint is loaded right afterwards.
    """

    def __init__(self, num_classes=1000, weights="DEFAULT"):
        super().__init__()
        # BUG FIX: the original passed the enum *class* SqueezeNet1_0_Weights
        # rather than one of its members. torchvision's weight verification
        # rejects that with a TypeError, so the model could never be built.
        backbone = squeezenet1_0(weights=weights)
        self.features = backbone.features
        self.classifier = nn.Sequential(
            # in_channels must match the channel count produced by
            # self.features (512 for SqueezeNet 1.0).
            nn.Conv2d(512, num_classes, kernel_size=1),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

    def forward(self, x):
        """Return per-class scores of shape ``(N, num_classes)``.

        ``x`` is presumably a batch of 3-channel images, ``(N, 3, H, W)``
        (TODO confirm with callers). The outputs are non-negative because the
        head applies ReLU before pooling.
        """
        x = self.features(x)
        x = self.classifier(x)  # pooled to (N, num_classes, 1, 1)
        return torch.flatten(x, 1)  # -> (N, num_classes)
# Build the classifier with one output per label. Use NUM_CLASSES instead of a
# hard-coded 7 so the head cannot silently drift from the configured label set.
MODEL = SqueezeNet1_0WithDropout(num_classes=NUM_CLASSES)
# Import-time side effect kept from the original: echo the label set.
print(CLASSES)

# Image transform pipeline applied to each input image.
preprocess = transforms.Compose(
    [
        # Resize every image to a fixed 64x64 resolution.
        transforms.Resize((64, 64)),
        # Convert to a CHW float tensor (scaled to [0, 1] for PIL/uint8 input).
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5 on 3 channels: maps [0, 1] to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
# Custom dataset wrapper
class CustomDataset(Dataset):
    """Expose any indexable collection of pairs as a ``torch`` ``Dataset``.

    Each item of the wrapped collection must unpack into exactly two values
    (presumably ``(image, label)``); items are returned as 2-tuples.
    """

    def __init__(self, dataset):
        # Store the underlying collection; indexing and length delegate to it.
        self.data = dataset

    def __len__(self):
        """Return the size of the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return item ``idx`` of the wrapped collection as a 2-tuple.

        Unpacking fails if the stored item does not hold exactly two values.
        """
        sample, target = self.data[idx]
        return sample, target