# Source code for torchpack.models.vision.mobilenetv1

from typing import List, Tuple, Union

import torch
from torch import nn

from torchpack.models.utils import make_divisible

__all__ = ['MobileNetV1', 'MobileBlockV1']


class MobileBlockV1(nn.Sequential):
    """Depthwise-separable convolution block.

    The block is a sequence of six modules: a depthwise convolution
    (``groups=in_channels``), batch normalization and an in-place ReLU,
    followed by a 1x1 pointwise convolution, batch normalization and another
    in-place ReLU. Neither convolution has a bias term.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the pointwise
            convolution.
        kernel_size: Kernel size of the depthwise convolution, either a single
            int or a ``(height, width)`` pair.
        stride: Stride of the depthwise convolution (keyword-only).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 *,
                 stride: int = 1) -> None:
        # Pad by half the kernel extent along each axis. A tuple kernel size
        # needs per-axis padding; ``kernel_size // 2`` is only defined for
        # an int and raises TypeError for a tuple.
        padding: Union[int, Tuple[int, ...]]
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

        super().__init__(
            # Depthwise: one filter per input channel.
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size,
                      stride=stride,
                      padding=padding,
                      groups=in_channels,
                      bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            # Pointwise: 1x1 convolution that mixes channels.
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # Keep the constructor arguments available for introspection.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
class MobileNetV1(nn.Module):
    """MobileNetV1 classification network.

    The network consists of a strided 3x3 stem convolution, a stack of
    :class:`MobileBlockV1` depthwise-separable blocks configured by
    :attr:`layers`, an average over the two trailing dimensions of the
    feature map, and a linear classifier.

    Args:
        in_channels: Number of channels of the input tensor.
        num_classes: Number of outputs of the final linear layer.
        width_multiplier: Factor applied to every channel count in
            :attr:`layers` before it is passed to ``make_divisible(..., 8)``
            (presumably rounding to a multiple of 8 -- confirm against
            ``torchpack.models.utils``).
    """

    # One ``(out_channels, num_blocks, stride)`` entry per stage. Only the
    # channel count of the first entry is used: it sizes the stem convolution,
    # whose stride is fixed to 2 in ``__init__``.
    layers: List = [(32, 1, 2), (64, 1, 1), (128, 2, 2), (256, 2, 2),
                    (512, 6, 2), (1024, 2, 2)]

    def __init__(self,
                 *,
                 in_channels: int = 3,
                 num_classes: int = 1000,
                 width_multiplier: float = 1) -> None:
        super().__init__()

        # ``self.layers[0]`` is a tuple, so the channel count has to be
        # indexed out before scaling. Multiplying the tuple itself repeats it
        # for an int multiplier and raises TypeError for a float one.
        out_channels = make_divisible(self.layers[0][0] * width_multiplier, 8)
        layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels,
                          out_channels,
                          3,
                          stride=2,
                          padding=1,
                          bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
        ])
        in_channels = out_channels

        for out_channels, num_blocks, strides in self.layers[1:]:
            out_channels = make_divisible(out_channels * width_multiplier, 8)
            # Only the first block of a stage uses the stage stride; the
            # remaining blocks use stride 1.
            for stride in [strides] + [1] * (num_blocks - 1):
                layers.append(
                    MobileBlockV1(in_channels, out_channels, 3, stride=stride))
                in_channels = out_channels

        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(in_channels, num_classes)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the convolution and linear layers in place.

        Convolution weights use Kaiming-normal initialization (``fan_out``
        mode, ReLU nonlinearity); linear weights are drawn from a normal
        distribution with ``std=0.01``. Biases, where present, are zeroed.
        Other module types are left untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return the classifier output.

        Args:
            x: Input tensor for the 2D convolutional feature extractor.

        Returns:
            The output of the linear classifier applied to the feature map
            averaged over dimensions 2 and 3.
        """
        x = self.features(x)
        # Average over dimensions 2 and 3 of the feature map.
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x