Source code for torchpack.models.vision.mobilenetv2

from typing import List, Tuple, Union

import torch
from torch import nn

from torchpack.models.utils import make_divisible

__all__ = ['MobileNetV2', 'MobileBlockV2']


class MobileBlockV2(nn.Module):
    """Depthwise-separable convolution block with an optional 1x1 expansion.

    The block is built as a single ``nn.Sequential`` stored in ``self.layers``:

    * when ``expansion != 1``: a 1x1 convolution widening ``in_channels`` to
      ``round(in_channels * expansion)``, followed by BatchNorm and ReLU6;
    * a ``kernel_size`` convolution with ``groups`` equal to its channel
      count (one filter per channel), followed by BatchNorm and ReLU6;
    * a 1x1 convolution to ``out_channels`` followed by BatchNorm only
      (no activation).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        kernel_size: Kernel size of the grouped convolution, either a single
            int or a ``(height, width)`` tuple. The padding is
            ``kernel_size // 2`` per dimension.
        stride: Stride of the grouped convolution.
        expansion: Channel multiplier for the hidden width. A value equal
            to ``1`` skips the leading 1x1 expansion convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 *,
                 stride: int = 1,
                 expansion: int = 1) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.expansion = expansion

        # `kernel_size` may be a tuple; `tuple // 2` would raise TypeError,
        # so the padding is derived per dimension in that case.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

        layers = []
        if expansion == 1:
            # No expansion stage: the grouped conv runs on the input width.
            mid_channels = in_channels
        else:
            mid_channels = round(in_channels * expansion)
            layers += [
                nn.Conv2d(in_channels, mid_channels, 1, bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU6(inplace=True),
            ]

        layers += [
            # groups == channels: every channel is convolved independently.
            nn.Conv2d(mid_channels,
                      mid_channels,
                      kernel_size,
                      stride=stride,
                      padding=padding,
                      groups=mid_channels,
                      bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU6(inplace=True),
            # Final 1x1 projection has no activation after its BatchNorm.
            nn.Conv2d(mid_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x``.

        The input is added to the output (skip connection) only when the
        block keeps the channel count and uses ``stride == 1``; otherwise
        the plain output of ``self.layers`` is returned.
        """
        out = self.layers(x)
        if self.in_channels == self.out_channels and self.stride == 1:
            return x + out
        return out
class MobileNetV2(nn.Module):
    """Image classifier assembled from a stack of ``MobileBlockV2`` blocks.

    The architecture is described by the class attribute ``layers``:

    * first entry: base width of the stem (3x3 convolution, stride 2);
    * middle entries: ``(expansion, out_channels, num_blocks, stride)``
      tuples, one per stage. Only the first block of a stage uses the
      given stride, the remaining blocks use stride 1;
    * last entry: base width of the final 1x1 convolution.

    Every width is scaled by ``width_multiplier`` and passed through
    ``make_divisible(..., 8)`` (presumably rounding to a multiple of 8;
    see ``torchpack.models.utils`` to confirm).

    Args:
        in_channels: Number of channels of the input tensor.
        num_classes: Output size of the linear classifier.
        width_multiplier: Scale factor applied to every base width.
    """

    layers: List = [
        32,
        (1, 16, 1, 1),
        (6, 24, 2, 2),
        (6, 32, 3, 2),
        (6, 64, 4, 2),
        (6, 96, 3, 1),
        (6, 160, 3, 2),
        (6, 320, 1, 1),
        1280,
    ]

    def __init__(self,
                 *,
                 in_channels: int = 3,
                 num_classes: int = 1000,
                 width_multiplier: float = 1) -> None:
        super().__init__()
        spec = self.layers

        # Stem: strided 3x3 convolution.
        width = make_divisible(spec[0] * width_multiplier, 8)
        stages = [
            nn.Sequential(
                nn.Conv2d(in_channels,
                          width,
                          3,
                          stride=2,
                          padding=1,
                          bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU6(inplace=True),
            )
        ]
        prev_width = width

        # Body: one MobileBlockV2 per entry of each stage's stride schedule.
        for expansion, base_width, num_blocks, first_stride in spec[1:-1]:
            width = make_divisible(base_width * width_multiplier, 8)
            stride_schedule = [first_stride, *([1] * (num_blocks - 1))]
            for block_stride in stride_schedule:
                block = MobileBlockV2(prev_width,
                                      width,
                                      3,
                                      stride=block_stride,
                                      expansion=expansion)
                stages.append(block)
                prev_width = width

        # Head: 1x1 convolution. `min_value=1280` is forwarded to
        # make_divisible (presumably a lower bound on the head width).
        width = make_divisible(spec[-1] * width_multiplier,
                               8,
                               min_value=1280)
        stages.append(
            nn.Sequential(
                nn.Conv2d(prev_width, width, 1, bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU6(inplace=True),
            ))
        prev_width = width

        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(prev_width, num_classes)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the weights of all convolution and linear layers.

        Conv2d weights use Kaiming-normal init (``fan_out``, ``relu``);
        Linear weights are drawn from a normal with ``std=0.01``. Any bias
        that exists is set to zero. Other module types are left untouched.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.01)
            else:
                continue
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor, average over dims 2 and 3, classify."""
        feature_map = self.features(x)
        pooled = feature_map.mean([2, 3])
        return self.classifier(pooled)