class mnist(vision_dataloader):
    """MNIST dataset wrapper that builds train/test ``DataLoader`` objects.

    Each sample goes through the same pipeline:
    ``transforms.ToTensor()`` -> ``Normalize((0.1307,), (0.3081,))`` -> ``torch.flatten``.
    The images therefore come out as 1-D tensors. That should be 784 values
    for a 1x28x28 MNIST image, but verify this against the torchvision version in use.
    """

    def __init__(self, name='mnist', train_batch_size=64, test_batch_size=64):
        """Store loader configuration.

        Args:
            name: Identifier forwarded to the ``vision_dataloader`` base class.
            train_batch_size: Batch size used for the training ``DataLoader``.
            test_batch_size: Batch size used for the test ``DataLoader``.
        """
        super().__init__(name=name)
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size

    def load(self, cache_dir='./data/', *args, **kwargs):
        """Build the MNIST train and test loaders.

        Args:
            cache_dir: Root directory handed to ``MNIST(root=...)``. The dataset
                is constructed with ``download=True``, so torchvision fetches the
                files there when they are missing.
            *args, **kwargs: Accepted but ignored. Presumably they exist to match
                the base-class ``load`` signature; confirm against
                ``vision_dataloader``.

        Returns:
            dict: ``{'train_loader': DataLoader, 'test_loader': DataLoader}``.
            The training loader shuffles; the test loader does not.
        """
        # Shared preprocessing for both splits. Normalize with the constants
        # below, then flatten each image tensor to 1-D.
        pipeline = Compose([
            transforms.ToTensor(),
            Normalize((0.1307,), (0.3081,)),
            torch.flatten,
        ])

        def _split_loader(is_train, batch_size):
            # One factory for both splits: only the split flag and the batch
            # size differ. Shuffling is enabled only for the training split.
            split = MNIST(root=cache_dir, train=is_train, download=True, transform=pipeline)
            return DataLoader(split, batch_size=batch_size, shuffle=is_train)

        # Dict literals evaluate left to right, so the train split is still
        # built before the test split.
        return {
            'train_loader': _split_loader(True, self.train_batch_size),
            'test_loader': _split_loader(False, self.test_batch_size),
        }