|
import os |
|
import copy |
|
from PIL import Image |
|
import numpy as np |
|
|
|
import torch |
|
import torch.utils.data as data |
|
from torchvision import transforms, datasets |
|
|
|
# Default directory where the MNIST files are stored / downloaded.
DATA_ROOTS = 'data'


class MNIST(data.Dataset):
    """MNIST wrapper that yields 3-channel (RGB) PIL images.

    Wraps ``torchvision.datasets.mnist.MNIST`` but reads its raw ``data`` /
    ``targets`` tensors directly, converting each grayscale digit into an
    RGB PIL image so it can be fed to transforms/models that expect three
    channels.

    Args:
        root: Directory holding (or receiving) the MNIST files. Created if
            it does not exist yet.
        train: If True use the training split, otherwise the test split.
        image_transforms: Optional callable applied to the RGB PIL image
            before it is returned.
        download: Forwarded to torchvision. Defaults to True, matching the
            previous hard-coded behaviour.
    """

    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None,
                 download=True):
        super().__init__()
        # exist_ok=True removes the isdir()/makedirs() race when several
        # processes (DataLoader workers, distributed ranks) construct the
        # dataset concurrently; it still raises if `root` is a regular file.
        os.makedirs(root, exist_ok=True)
        self.image_transforms = image_transforms
        self.dataset = datasets.mnist.MNIST(root, train=train,
                                            download=download)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        ``image`` is an RGB PIL image, or whatever ``image_transforms``
        returns when it is set. ``label`` is a Python int.
        """
        img = self.dataset.data[index]
        target = int(self.dataset.targets[index])
        # Pillow infers mode 'L' from a 2-D uint8 array, and the explicit
        # ``mode=`` argument of fromarray() is deprecated in recent Pillow.
        # Assumes torchvision stores MNIST pixels as uint8 — TODO confirm
        # if the upstream storage format ever changes.
        img = Image.fromarray(img.numpy()).convert('RGB')
        if self.image_transforms is not None:
            img = self.image_transforms(img)
        return img, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.dataset)