# Uploaded by wgcban ("Upload 98 files", commit 803ef9e; original file size 857 bytes).
import os
import copy
from PIL import Image
import numpy as np
import torch
import torch.utils.data as data
from torchvision import transforms, datasets
# Default dataset root directory; a relative path, so it resolves against the
# current working directory at the time the dataset is constructed.
DATA_ROOTS = "data"
class MNIST(data.Dataset):
    """MNIST dataset wrapper that yields RGB PIL images and integer labels.

    Wraps ``torchvision.datasets.mnist.MNIST`` and converts each stored
    grayscale sample to a 3-channel ``PIL.Image`` before optionally applying
    ``image_transforms``.

    Args:
        root: Directory where the MNIST files are stored / downloaded to.
            Created (including parents) if it does not already exist.
        train: If True, use the training split; otherwise the test split.
        image_transforms: Optional callable applied to each RGB PIL image.
        download: Forwarded to the torchvision dataset. Defaults to True,
            which matches the previous hard-coded behaviour.
    """

    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None,
                 download=True):
        super().__init__()
        # exist_ok=True replaces the former isdir()/makedirs() pair, which
        # could raise FileExistsError if another process created the
        # directory between the check and the creation.
        os.makedirs(root, exist_ok=True)
        self.image_transforms = image_transforms
        self.dataset = datasets.mnist.MNIST(root, train=train, download=download)

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        ``image`` is the RGB PIL image, passed through ``image_transforms``
        when one was given. ``target`` is the label converted to ``int``.
        """
        # Read the stored sample and label straight from the wrapped dataset's
        # ``data`` / ``targets`` (the ``.numpy()`` call below presumes ``data``
        # holds torch tensors — confirm if the backing dataset changes).
        img = self.dataset.data[index]
        target = int(self.dataset.targets[index])
        # Build a single-channel 'L' image, then convert to 'RGB' so every
        # sample has three channels (grayscale value copied to R, G and B).
        img = Image.fromarray(img.numpy(), mode='L').convert('RGB')
        if self.image_transforms is not None:
            img = self.image_transforms(img)
        return img, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.dataset)