Source code for torchpack.datasets.vision.imagenet

import warnings
from typing import Any, Callable, Dict, Optional

from torchvision import datasets
from torchvision.transforms import (CenterCrop, Compose, Normalize,
                                    RandomHorizontalFlip, RandomResizedCrop,
                                    Resize, ToTensor)

from torchpack.datasets.dataset import Dataset

__all__ = ['ImageNet']


class ImageNetDataset(datasets.ImageNet):
    def __init__(self,
                 *,
                 root: str,
                 split: str,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None) -> None:
        super().__init__(root=root,
                         split=('train' if split == 'train' else 'val'),
                         transform=transform,
                         target_transform=target_transform)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore')
            image, label = super().__getitem__(index)
        return {'image': image, 'class': label}


[docs]class ImageNet(Dataset): def __init__(self, *, root: str, num_classes: int = 1000, transforms: Optional[Dict[str, Callable]] = None) -> None: if transforms is None: transforms = dict() if 'train' not in transforms: transforms['train'] = Compose([ RandomResizedCrop(224), RandomHorizontalFlip(), ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) if 'test' not in transforms: transforms['test'] = Compose([ Resize(256), CenterCrop(224), ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) super().__init__({ split: ImageNetDataset(root=root, split=split, transform=transforms[split]) for split in ['train', 'test'] }) indices = dict() for k in range(num_classes): indices[k * (1000 // num_classes)] = k for dataset in self.values(): samples = [] for x, c in dataset.samples: if c in indices: samples.append((x, indices[c])) dataset.samples = samples targets = [] for c in dataset.targets: if c in indices: targets.append(indices[c]) dataset.targets = targets classes = [] for c, x in enumerate(dataset.classes): if c in indices: classes.append(x) dataset.classes = classes class_to_idx = {} for x, c in dataset.class_to_idx.items(): if c in indices: class_to_idx[x] = c dataset.class_to_idx = class_to_idx