Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 20 additions & 39 deletions torchvision/models/shufflenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import torch
import torch.nn as nn
from .utils import load_state_dict_from_url

__all__ = ['ShuffleNetV2', 'shufflenetv2',
'shufflenetv2_x0_5', 'shufflenetv2_x1_0',
'shufflenetv2_x1_5', 'shufflenetv2_x2_0']
__all__ = ['ShuffleNetV2', 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', 'shufflenetv2_x1_5', 'shufflenetv2_x2_0']

model_urls = {
'shufflenetv2_x0.5':
Expand Down Expand Up @@ -85,16 +84,13 @@ def forward(self, x):


class ShuffleNetV2(nn.Module):
def __init__(self, num_classes=1000, width_mult=1):
def __init__(self, stage_out_channels, num_classes=1000):
super(ShuffleNetV2, self).__init__()

try:
self.stage_out_channels = self._getStages(float(width_mult))
except KeyError:
raise ValueError('width_mult {} is not supported'.format(width_mult))

self.stage_out_channels = stage_out_channels
input_channels = 3
output_channels = self.stage_out_channels[0]

self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels),
Expand Down Expand Up @@ -134,47 +130,32 @@ def forward(self, x):
x = self.fc(x)
return x

@staticmethod
def _getStages(mult):
stages = {
'0.5': [24, 48, 96, 192, 1024],
'1.0': [24, 116, 232, 464, 1024],
'1.5': [24, 176, 352, 704, 1024],
'2.0': [24, 244, 488, 976, 2048],
}
return stages[str(mult)]


def shufflenetv2(pretrained=False, num_classes=1000, width_mult=1, **kwargs):
model = ShuffleNetV2(num_classes=num_classes, width_mult=width_mult)
def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs):
model = ShuffleNetV2(stage_out_channels=stage_out_channels, **kwargs)

if pretrained:
# change width_mult to float
if isinstance(width_mult, int):
width_mult = float(width_mult)
model_type = ('_'.join([ShuffleNetV2.__name__, 'x' + str(width_mult)]))
try:
model_url = model_urls[model_type.lower()]
except KeyError:
raise ValueError('model {} is not support'.format(model_type))
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported'.format(model_type))
model.load_state_dict(torch.utils.model_zoo.load_url(model_url))
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_urls, progress=progress)
model.load_state_dict(state_dict)

return model


def shufflenetv2_x0_5(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, 0.5)
def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, [24, 48, 96, 192, 1024], **kwargs)


def shufflenetv2_x1_0(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, 1)
def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, [24, 116, 232, 464, 1024], **kwargs)


def shufflenetv2_x1_5(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, 1.5)
def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, [24, 176, 352, 704, 1024], **kwargs)


def shufflenetv2_x2_0(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, 2)
def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, [24, 244, 488, 976, 2048], **kwargs)