# Source code for torchpack.datasets.vision.cifar

from typing import Any, Callable, Dict, Optional

from torchvision import datasets
from torchvision.transforms import (Compose, Normalize, RandomCrop,
                                    RandomHorizontalFlip, Resize, ToTensor)

from torchpack.datasets.dataset import Dataset

__all__ = ['CIFAR']


class CIFAR10Dataset(datasets.CIFAR10):
    """``datasets.CIFAR10`` selected by split name, yielding dict samples."""

    def __init__(self,
                 *,
                 root: str,
                 split: str,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 download: bool = True) -> None:
        """Forward the arguments to ``datasets.CIFAR10``.

        Args:
            root: Forwarded unchanged as ``root``.
            split: Split name; ``train=True`` is passed to the parent only
                when this equals ``'train'``, otherwise ``train=False``.
            transform: Forwarded unchanged as ``transform``.
            target_transform: Forwarded unchanged as ``target_transform``.
            download: Forwarded unchanged as ``download``.
        """
        # Every split name other than 'train' maps to train=False.
        wants_train = split == 'train'
        parent_kwargs = dict(root=root,
                             train=wants_train,
                             transform=transform,
                             target_transform=target_transform,
                             download=download)
        super().__init__(**parent_kwargs)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the parent's two-element sample as a dict.

        The first element is stored under ``'image'`` and the second under
        ``'class'``.
        """
        sample: Dict[str, Any] = {}
        sample['image'], sample['class'] = super().__getitem__(index)
        return sample


class CIFAR100Dataset(datasets.CIFAR100):
    """``datasets.CIFAR100`` selected by split name, yielding dict samples."""

    def __init__(self,
                 *,
                 root: str,
                 split: str,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 download: bool = True) -> None:
        """Forward the arguments to ``datasets.CIFAR100``.

        Args:
            root: Forwarded unchanged as ``root``.
            split: Split name; ``train=True`` is passed to the parent only
                when this equals ``'train'``, otherwise ``train=False``.
            transform: Forwarded unchanged as ``transform``.
            target_transform: Forwarded unchanged as ``target_transform``.
            download: Forwarded unchanged as ``download``.
        """
        # Every split name other than 'train' maps to train=False.
        is_training_split = split == 'train'
        super().__init__(
            root=root,
            train=is_training_split,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the parent's two-element sample as a dict.

        The first element is stored under ``'image'`` and the second under
        ``'class'``.
        """
        record: Dict[str, Any] = {}
        record['image'], record['class'] = super().__getitem__(index)
        return record


class CIFAR(Dataset):
    """CIFAR dataset exposing a ``'train'`` and a ``'test'`` split.

    Args:
        root: Passed as ``root`` to the per-split dataset objects.
        num_classes: ``10`` selects ``CIFAR10Dataset``, ``100`` selects
            ``CIFAR100Dataset``.
        transforms: Optional mapping from split name (``'train'`` /
            ``'test'``) to the image transform used for that split. Any
            split missing from the mapping gets the default pipeline built
            below. The mapping is copied; the caller's object is never
            modified.

    Raises:
        NotImplementedError: If ``num_classes`` is neither 10 nor 100.
    """

    def __init__(self,
                 *,
                 root: str,
                 num_classes: int = 10,
                 transforms: Optional[Dict[str, Callable]] = None) -> None:
        if num_classes == 10:
            CIFARDataset = CIFAR10Dataset
        elif num_classes == 100:
            CIFARDataset = CIFAR100Dataset
        else:
            raise NotImplementedError(f'CIFAR-{num_classes} is not supported.')

        # Work on a private copy so that filling in the defaults below does
        # not leak into the dict the caller handed in.
        transforms = dict(transforms) if transforms is not None else {}

        # Normalization constants shared by both default pipelines.
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]

        if 'train' not in transforms:
            transforms['train'] = Compose([
                RandomCrop(32, padding=4),
                RandomHorizontalFlip(),
                ToTensor(),
                Normalize(mean=list(mean), std=list(std)),
            ])

        if 'test' not in transforms:
            transforms['test'] = Compose([
                Resize(32),
                ToTensor(),
                Normalize(mean=list(mean), std=list(std)),
            ])

        super().__init__({
            split: CIFARDataset(root=root,
                                split=split,
                                transform=transforms[split])
            for split in ['train', 'test']
        })