"""Source code for hetseq.data.mnist_dataset."""

from functools import lru_cache
from torchvision import transforms
from PIL import Image

import numpy as np
import torch
import torch.utils.data
from torch.utils.data.dataloader import default_collate


class MNISTDataset(torch.utils.data.Dataset):
    """MNIST dataset backed by a single file written with ``torch.save``.

    The file at ``path`` is loaded eagerly with ``torch.load``. Element 0 of
    the loaded object is used as the image tensor and element 1 as the label
    tensor. Each sample is converted to an 8-bit grayscale PIL image
    (``mode='L'``) and passed through ``self.transform``. The default
    transform is ``ToTensor()`` followed by
    ``Normalize((0.1307,), (0.3081,))``.

    NOTE(review): assumes images are an (N, H, W) uint8 tensor, as in
    torchvision's ``processed/training.pt``. TODO confirm for other files.
    """

    def __init__(self, path, transform=None):
        """Load the dataset stored at ``path``.

        Args:
            path: file holding an ``(images, labels)`` pair saved with
                ``torch.save``.
            transform: optional callable applied to each PIL image in
                ``__getitem__``. When ``None`` (the default), the standard
                ``ToTensor()`` + ``Normalize((0.1307,), (0.3081,))``
                pipeline is used, matching the previous hard-coded behavior.
        """
        self.data = None
        self.path = path
        self.read_data(self.path)
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ])
        self.transform = transform

    def read_data(self, path):
        """Load ``(images, labels)`` from ``path`` and store them on ``self``.

        Sets ``self.data`` (the raw loaded object), ``self.image``,
        ``self.label`` and ``self._len`` (the number of images).
        """
        # NOTE(review): torch.load unpickles its input; only point this at
        # trusted files.
        self.data = torch.load(path)
        self._len = len(self.data[0])
        self.image = self.data[0]
        self.label = self.data[1]

    # Deliberately NOT wrapped in functools.lru_cache. A cache on an instance
    # method is stored on the class, keeps every dataset instance (and its
    # full image tensor) alive, and returns the same tensor object for
    # repeated indices. With 8 slots and shuffled sampling it barely ever
    # hits anyway.
    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for sample ``index``.

        ``label`` is returned as a Python ``int``. ``transformed_image`` is
        whatever ``self.transform`` produces from the PIL image. With the
        default transform, that is a normalized float tensor.
        """
        img, target = self.image[index], int(self.label[index])
        img = Image.fromarray(img.numpy(), mode='L')
        img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the number of samples loaded by ``read_data``."""
        return self._len

    def ordered_indices(self):
        """Return an ordered list of indices.

        Batches will be constructed based on this order.
        """
        return np.arange(len(self))

    def num_tokens(self, index: int):
        """Return the size of sample ``index``; every sample counts as 1."""
        return 1

    def collater(self, samples):
        """Merge a list of ``(image, label)`` samples into a batch.

        Returns ``None`` for an empty list; otherwise delegates to
        ``torch.utils.data.dataloader.default_collate``.
        """
        if not samples:
            return None
        return default_collate(samples)

    def set_epoch(self, epoch):
        """No-op: this dataset has no per-epoch state."""
        pass
if __name__ == '__main__':
    # Quick sanity check: load a processed MNIST file and print the number of
    # samples and the image/label tensor shapes. An alternative file may be
    # passed as the first command-line argument.
    import sys

    path = (sys.argv[1] if len(sys.argv) > 1
            else '/scratch365/yding4/mnist/MNIST/processed/training.pt')
    dataset = MNISTDataset(path)
    print(len(dataset))
    # Reuse the tensors the dataset already loaded instead of running
    # torch.load on the same file a second time (dataset.image / .label are
    # exactly data[0] / data[1] of the loaded object).
    print(dataset.image.shape, dataset.label.shape)