Skip to content

cifar10

Bases: vision_dataloader

Source code in tinybig/data/vision_dataloader.py
class cifar10(vision_dataloader):
    """Data loader for the CIFAR-10 image classification dataset.

    Stores the batch sizes used for the training and testing splits and,
    on :meth:`load`, builds one ``DataLoader`` per split over images that
    have been converted to tensors, normalized, and flattened.
    """

    def __init__(self, name='cifar10', train_batch_size=64, test_batch_size=64):
        """Record the loader name and the per-split batch sizes.

        Args:
            name: Identifier forwarded to the base-class constructor.
            train_batch_size: Batch size used by the training loader.
            test_batch_size: Batch size used by the testing loader.
        """
        super().__init__(name=name)
        self.train_batch_size, self.test_batch_size = train_batch_size, test_batch_size

    def load(self, cache_dir='./data/', *args, **kwargs):
        """Build the CIFAR-10 training and testing data loaders.

        Both splits are requested with ``download=True`` and share a single
        preprocessing pipeline.

        Args:
            cache_dir: Root directory handed to ``CIFAR10`` for the dataset
                files.
            *args: Accepted but not used by this method.
            **kwargs: Accepted but not used by this method.

        Returns:
            A dict with keys ``'train_loader'`` (shuffled), ``'test_loader'``
            (not shuffled), and ``'classes'`` (tuple of the ten label names).
        """
        # Per-channel normalization statistics.
        # NOTE(review): presumably the CIFAR-10 training-set mean/std -- confirm.
        channel_mean = (0.4914, 0.4822, 0.4465)
        channel_std = (0.2470, 0.2435, 0.2616)

        # No augmentation is applied (a random horizontal flip is deliberately
        # left out), so the same pipeline can serve both splits.
        preprocessing = Compose([
            transforms.ToTensor(),
            Normalize(channel_mean, channel_std),
            # Collapse every image tensor into a single 1-D vector.
            torch.flatten,
        ])

        training_set = CIFAR10(
            root=cache_dir, train=True, download=True, transform=preprocessing)
        train_loader = DataLoader(
            training_set, batch_size=self.train_batch_size, shuffle=True)

        testing_set = CIFAR10(
            root=cache_dir, train=False, download=True, transform=preprocessing)
        test_loader = DataLoader(
            testing_set, batch_size=self.test_batch_size, shuffle=False)

        label_names = (
            'plane', 'car', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck',
        )

        return {
            'train_loader': train_loader,
            'test_loader': test_loader,
            'classes': label_names,
        }