import torch
from torchvision import datasets, transforms


class MNISTDataModule:
    """Build MNIST training and test ``DataLoader`` objects.

    Args:
        batch_size: Batch size of the training loader (shuffled).
        val_batch_size: Batch size of the test loader (not shuffled).
        data_dir: Root directory passed to ``datasets.MNIST``. Defaults to
            ``'./data'``, the location that was previously hard-coded.
        download: Forwarded to ``datasets.MNIST``; when true, torchvision
            downloads the files into ``data_dir`` if they are missing.
        num_workers: Forwarded to both ``DataLoader`` objects. ``0`` (the
            ``DataLoader`` default, and the previous behavior) loads batches
            in the calling process.
    """

    def __init__(
        self,
        batch_size: int = 64,
        val_batch_size: int = 1000,
        *,
        data_dir: str = './data',
        download: bool = True,
        num_workers: int = 0,
    ):
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.data_dir = data_dir
        self.download = download
        self.num_workers = num_workers

    def get_dataloaders(self):
        """Create training and test dataloaders.

        Returns:
            tuple: ``(train_loader, test_loader)``. The training loader
            reshuffles every epoch; the test loader keeps dataset order so
            evaluation is reproducible.
        """
        # ToTensor scales pixel values to [0, 1]; Normalize((0.5,), (0.5,))
        # then maps them to [-1, 1] via (x - 0.5) / 0.5 on the single channel.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

        train_dataset = datasets.MNIST(
            root=self.data_dir,
            train=True,
            transform=transform,
            download=self.download,
        )
        test_dataset = datasets.MNIST(
            root=self.data_dir,
            train=False,
            transform=transform,
            download=self.download,
        )

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return train_loader, test_loader