from configs import * from torchvision.datasets import ImageFolder from torch.utils.data import random_split, DataLoader, Dataset import torch torch.manual_seed(RANDOM_SEED) # Set seed torch.manual_seed(RANDOM_SEED) def load_data(combined_dir, preprocess, batch_size=BATCH_SIZE): dataset = ImageFolder(combined_dir, transform=preprocess) # Classes classes = dataset.classes print("Classes: ", *classes, sep=", ") print("Length of total dataset: ", len(dataset)) # Split the dataset into train and validation sets train_size = int(0.8 * len(dataset)) val_size = len(dataset) - train_size train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) # Create data loaders for the custom dataset train_loader = DataLoader( CustomDataset(train_dataset), batch_size=batch_size, shuffle=True, num_workers=0 ) valid_loader = DataLoader( CustomDataset(val_dataset), batch_size=batch_size, num_workers=0, shuffle=False ) return train_loader, valid_loader def load_test_data(test_dir, preprocess, batch_size=BATCH_SIZE): test_dataset = ImageFolder(test_dir, transform=preprocess) # Create a DataLoader for the test data test_dataloader = DataLoader( CustomDataset(test_dataset), batch_size=batch_size, shuffle=False, num_workers=0 ) return test_dataloader