"""Shared training configuration: hyper-parameters, data paths, model and preprocessing.

Importing this module instantiates ``MODEL`` and builds the ``preprocess``
pipeline as a side effect, so every script that imports it shares the same
objects.
"""

import torch
from torch.utils.data import Dataset
from torchvision import transforms

# The wildcard is kept on purpose: it provides EfficientNetB3WithNorm below and
# re-exports whatever else `models` defines to modules that import from here.
from models import *  # noqa: F401,F403

# Constants
RANDOM_SEED = 123
BATCH_SIZE = 8
NUM_EPOCHS = 150
WARMUP_EPOCHS = 5
LEARNING_RATE = 0.0001
# NOTE(review): STEP_SIZE / GAMMA look like step-decay LR scheduler settings —
# confirm against the training script.
STEP_SIZE = 10
GAMMA = 0.3
CUTMIX_ALPHA = 0.3

# The device is deliberately pinned to CPU. To use the first GPU when one is
# present, replace the argument with:
#   "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device("cpu")

NUM_PRINT = 100
TASK = 1

# Every data directory ends in "Task " (with the trailing space).
# NOTE(review): presumably callers append the task number (see TASK) to form
# the final path — verify against the data-loading code.
RAW_DATA_DIR = r"data/train/raw/Task "
AUG_DATA_DIR = r"data/train/augmented/Task "
EXTERNAL_DATA_DIR = r"data/train/external/Task "
COMBINED_DATA_DIR = r"data/train/combined/Task "
TEST_DATA_DIR = r"data/test/Task "
TEMP_DATA_DIR = "data/temp/Task "

# Must stay equal to len(CLASSES); it sizes the model's output below.
NUM_CLASSES = 7
LABEL_SMOOTHING_EPSILON = 0.1
EARLY_STOPPING_PATIENCE = 20

CLASSES = [
    "Alzheimer Disease",
    "Cerebral Palsy",
    "Dystonia",
    "Essential Tremor",
    "Healthy",
    "Huntington Disease",
    "Parkinson Disease",
]

# The model is created once at import time; the checkpoint file is named after
# its class so that swapping the architecture also changes the save path.
MODEL = EfficientNetB3WithNorm(num_classes=NUM_CLASSES)
MODEL_SAVE_PATH = f"output/checkpoints/{type(MODEL).__name__}.pth"

preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # Convert to tensor
        # Positional arguments are (mean, std).
        # NOTE(review): presumably statistics measured on the training
        # images — confirm where 0.8289 / 0.2006 come from.
        transforms.Normalize(0.8289, 0.2006),
    ]
)


# Custom dataset class
class CustomDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of (image, label) pairs."""

    def __init__(self, dataset):
        """Store ``dataset``; it must support ``len()`` and integer indexing."""
        self.data = dataset

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(img, label)`` pair stored at position ``idx``.

        The item is unpacked into exactly two values, so an entry of any other
        length raises ``ValueError``.
        """
        img, label = self.data[idx]
        return img, label