From 69c3a1f86e1456851808ee22e8de6f9157b6f812 Mon Sep 17 00:00:00 2001 From: barrh Date: Tue, 7 May 2019 16:42:21 +0300 Subject: [PATCH 1/2] Enhance ShufflenetV2 Class shufflenetv2 receives `stages_repeats` and `stages_out_channels` arguments. --- torchvision/models/shufflenetv2.py | 37 +++++++++++++++++++----------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 5726cea2a22..a1cf2fd353b 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -84,13 +84,17 @@ def forward(self, x): class ShuffleNetV2(nn.Module): - def __init__(self, stage_out_channels, num_classes=1000): + def __init__(self, stages_repeats, stages_out_channels, num_classes=1000): super(ShuffleNetV2, self).__init__() - self.stage_out_channels = stage_out_channels - input_channels = 3 - output_channels = self.stage_out_channels[0] + if len(stages_repeats) != 3: + raise ValueError('expected stages_repeats as list of 3 positive ints') + if len(stages_out_channels) != 5: + raise ValueError('expected stages_out_channels as list of 5 positive ints') + self._stage_out_channels = stages_out_channels + input_channels = 3 + output_channels = self._stage_out_channels[0] self.conv1 = nn.Sequential( nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), nn.BatchNorm2d(output_channels), @@ -101,16 +105,15 @@ def __init__(self, stage_out_channels, num_classes=1000): self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] - stage_repeats = [4, 8, 4] for name, repeats, output_channels in zip( - stage_names, stage_repeats, self.stage_out_channels[1:]): + stage_names, stages_repeats, self._stage_out_channels[1:]): seq = [InvertedResidual(input_channels, output_channels, 2)] for i in range(repeats - 1): seq.append(InvertedResidual(output_channels, output_channels, 1)) setattr(self, name, nn.Sequential(*seq)) input_channels = output_channels - output_channels = self.stage_out_channels[-1] + output_channels = self._stage_out_channels[-1] self.conv5 = nn.Sequential( nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), nn.BatchNorm2d(output_channels), @@ -131,8 +134,8 @@ def forward(self, x): return x -def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs): - model = ShuffleNetV2(stage_out_channels=stage_out_channels, **kwargs) +def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): + model = ShuffleNetV2(*args, **kwargs) if pretrained: model_url = model_urls[arch] @@ -146,16 +149,24 @@ def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs): def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs): - return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, [24, 48, 96, 192, 1024], **kwargs) + return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, + [4, 8, 4], [24, 48, 96, 192, 1024], + num_classes=1000, **kwargs) def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs): - return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, [24, 116, 232, 464, 1024], **kwargs) + return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, + [4, 8, 4], [24, 116, 232, 464, 1024], + num_classes=1000, **kwargs) def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs): - return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, [24, 176, 352, 704, 1024], **kwargs) + return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, + [4, 8, 4], [24, 176, 352, 704, 1024], + num_classes=1000, **kwargs) def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs): - return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, [24, 244, 488, 976, 2048], **kwargs) + return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, + [4, 8, 4], [24, 244, 488, 976, 2048], + num_classes=1000, **kwargs) From 93ab83e9bfe77e7336b50abc0424cebb29b47769 Mon Sep 17 00:00:00 2001 From: barrh Date: Tue, 7 May 2019 21:27:33 +0300 Subject: [PATCH 2/2] remove explicit num_classes argument from utility functions --- torchvision/models/shufflenetv2.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index a1cf2fd353b..c56bab30bbd 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -150,23 +150,19 @@ def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs): return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, - [4, 8, 4], [24, 48, 96, 192, 1024], - num_classes=1000, **kwargs) + [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs): return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, - [4, 8, 4], [24, 116, 232, 464, 1024], - num_classes=1000, **kwargs) + [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs): return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, - [4, 8, 4], [24, 176, 352, 704, 1024], - num_classes=1000, **kwargs) + [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs): return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, - [4, 8, 4], [24, 244, 488, 976, 2048], - num_classes=1000, **kwargs) + [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)