From bc0020dadd3a9e36937bdff04b393b95a6762a2c Mon Sep 17 00:00:00 2001 From: Sergey Kolchenko Date: Thu, 26 Nov 2020 16:31:46 -0600 Subject: [PATCH 1/6] add regnet, res2net, resnest, sknet encoders --- .../encoders/__init__.py | 9 +- .../encoders/timm_regnet.py | 332 ++++++++++++++++++ .../encoders/timm_res2net.py | 160 +++++++++ .../encoders/timm_resnest.py | 205 +++++++++++ .../encoders/timm_sknet.py | 103 ++++++ 5 files changed, 808 insertions(+), 1 deletion(-) create mode 100644 segmentation_models_pytorch/encoders/timm_regnet.py create mode 100644 segmentation_models_pytorch/encoders/timm_res2net.py create mode 100644 segmentation_models_pytorch/encoders/timm_resnest.py create mode 100644 segmentation_models_pytorch/encoders/timm_sknet.py diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index e9cd7afb..a409a662 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -12,7 +12,10 @@ from .mobilenet import mobilenet_encoders from .xception import xception_encoders from .timm_efficientnet import timm_efficientnet_encoders - +from .timm_resnest import timm_resnest_encoders +from .timm_res2net import timm_res2net_encoders +from .timm_regnet import timm_regnet_encoders +from .timm_sknet import timm_sknet_encoders from ._preprocessing import preprocess_input encoders = {} @@ -27,6 +30,10 @@ encoders.update(mobilenet_encoders) encoders.update(xception_encoders) encoders.update(timm_efficientnet_encoders) +encoders.update(timm_resnest_encoders) +encoders.update(timm_res2net_encoders) +encoders.update(timm_regnet_encoders) +encoders.update(timm_sknet_encoders) def get_encoder(name, in_channels=3, depth=5, weights=None): diff --git a/segmentation_models_pytorch/encoders/timm_regnet.py b/segmentation_models_pytorch/encoders/timm_regnet.py new file mode 100644 index 00000000..e02ad59b --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_regnet.py @@ -0,0 +1,332 @@ +from ._base import EncoderMixin +from timm.models.regnet import RegNet +import torch.nn as nn + + +class RegNetEncoder(RegNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.head + + def get_stages(self): + return [ + nn.Identity(), + self.stem, + self.s1, + self.s2, + self.s3, + self.s4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("head.fc.weight") + state_dict.pop("head.fc.bias") + super().load_state_dict(state_dict, **kwargs) + + +regnet_weights = { + 'timm-regnetx_002': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth', + }, + 'timm-regnetx_004': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth', + }, + 'timm-regnetx_006': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth', + }, + 'timm-regnetx_008': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth', + }, + 'timm-regnetx_016': { + 'imagenet': 
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth', + }, + 'timm-regnetx_032': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth', + }, + 'timm-regnetx_040': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth', + }, + 'timm-regnetx_064': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth', + }, + 'timm-regnetx_080': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth', + }, + 'timm-regnetx_120': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth', + }, + 'timm-regnetx_160': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth', + }, + 'timm-regnetx_320': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth', + }, + 'timm-regnety_002': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth', + }, + 'timm-regnety_004': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth', + }, + 'timm-regnety_006': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth', + }, + 'timm-regnety_008': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth', + }, + 'timm-regnety_016': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth', + }, + 'timm-regnety_032': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth' + }, + 'timm-regnety_040': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth' + }, + 'timm-regnety_064': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth' + }, + 'timm-regnety_080': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth', + }, + 'timm-regnety_120': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth', + }, + 'timm-regnety_160': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth', + }, + 'timm-regnety_320': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth' + } +} + +pretrained_settings = {} +for model_name, sources in regnet_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + +# at this point I am too lazy to copy configs, so I just used the same configs from timm's repo + + +def _mcfg(**kwargs): + cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32) 
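+    # Each encoder entry below passes RegNet's generating parameters
+    # (w0, wa, wm, group_w, depth, and optionally se_ratio) through this
+    # helper, which merges them into timm's default block config.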
+ cfg.update(**kwargs) + return cfg + + +timm_regnet_encoders = { + 'timm-regnetx_002': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_002"], + 'params': { + 'out_channels': (3, 32, 24, 56, 152, 368), + 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13) + }, + }, + 'timm-regnetx_004': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_004"], + 'params': { + 'out_channels': (3, 32, 32, 64, 160, 384), + 'cfg': _mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22) + }, + }, + 'timm-regnetx_006': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_006"], + 'params': { + 'out_channels': (3, 32, 48, 96, 240, 528), + 'cfg': _mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16) + }, + }, + 'timm-regnetx_008': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_008"], + 'params': { + 'out_channels': (3, 32, 64, 128, 288, 672), + 'cfg': _mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16) + }, + }, + 'timm-regnetx_016': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_016"], + 'params': { + 'out_channels': (3, 32, 72, 168, 408, 912), + 'cfg': _mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18) + }, + }, + 'timm-regnetx_032': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_032"], + 'params': { + 'out_channels': (3, 32, 96, 192, 432, 1008), + 'cfg': _mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25) + }, + }, + 'timm-regnetx_040': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_040"], + 'params': { + 'out_channels': (3, 32, 80, 240, 560, 1360), + 'cfg': _mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23) + }, + }, + 'timm-regnetx_064': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_064"], + 'params': { + 'out_channels': (3, 32, 168, 392, 784, 1624), + 'cfg': _mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17) + }, + }, + 'timm-regnetx_080': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_080"], + 'params': { + 'out_channels': (3, 32, 80, 240, 720, 1920), + 'cfg': _mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23) + }, + }, + 'timm-regnetx_120': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_120"], + 'params': { + 'out_channels': (3, 32, 224, 448, 896, 2240), + 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19) + }, + }, + 'timm-regnetx_160': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_160"], + 'params': { + 'out_channels': (3, 32, 256, 512, 896, 2048), + 'cfg': _mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22) + }, + }, + 'timm-regnetx_320': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_320"], + 'params': { + 'out_channels': (3, 32, 336, 672, 1344, 2520), + 'cfg': _mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23) + }, + }, + #regnety + 'timm-regnety_002': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_002"], + 'params': { + 'out_channels': (3, 32, 24, 56, 152, 368), + 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25) + }, + }, + 'timm-regnety_004': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_004"], + 'params': { + 'out_channels': (3, 32, 48, 104, 208, 440), + 'cfg': 
_mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25) + }, + }, + 'timm-regnety_006': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_006"], + 'params': { + 'out_channels': (3, 32, 48, 112, 256, 608), + 'cfg': _mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25) + }, + }, + 'timm-regnety_008': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_008"], + 'params': { + 'out_channels': (3, 32, 64, 128, 320, 768), + 'cfg': _mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25) + }, + }, + 'timm-regnety_016': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_016"], + 'params': { + 'out_channels': (3, 32, 48, 120, 336, 888), + 'cfg': _mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25) + }, + }, + 'timm-regnety_032': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_032"], + 'params': { + 'out_channels': (3, 32, 72, 216, 576, 1512), + 'cfg': _mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25) + }, + }, + 'timm-regnety_040': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_040"], + 'params': { + 'out_channels': (3, 32, 128, 192, 512, 1088), + 'cfg': _mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25) + }, + }, + 'timm-regnety_064': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_064"], + 'params': { + 'out_channels': (3, 32, 144, 288, 576, 1296), + 'cfg': _mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25) + }, + }, + 'timm-regnety_080': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_080"], + 'params': { + 'out_channels': (3, 32, 168, 448, 896, 2016), + 'cfg': _mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25) + }, + }, + 'timm-regnety_120': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_120"], + 'params': { + 'out_channels': (3, 32, 224, 448, 896, 2240), + 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25) + }, + }, + 'timm-regnety_160': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_160"], + 'params': { + 'out_channels': (3, 32, 224, 448, 1232, 3024), + 'cfg': _mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25) + }, + }, + 'timm-regnety_320': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_320"], + 'params': { + 'out_channels': (3, 32, 232, 696, 1392, 3712), + 'cfg': _mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25) + }, + }, +} diff --git a/segmentation_models_pytorch/encoders/timm_res2net.py b/segmentation_models_pytorch/encoders/timm_res2net.py new file mode 100644 index 00000000..a2f32785 --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_res2net.py @@ -0,0 +1,160 @@ +from ._base import EncoderMixin +from timm.models.resnet import ResNet +from timm.models.res2net import Bottle2neck +import torch.nn as nn + + +class Res2NetEncoder(ResNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.fc + del self.global_pool + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.act1), + nn.Sequential(self.maxpool, 
self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("fc.bias") + state_dict.pop("fc.weight") + super().load_state_dict(state_dict, **kwargs) + + +res2net_weights = { + 'timm-res2net50_26w_4s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth' + }, + 'timm-res2net50_48w_2s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth' + }, + 'timm-res2net50_14w_8s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth', + }, + 'timm-res2net50_26w_6s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth', + }, + 'timm-res2net50_26w_8s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth', + }, + 'timm-res2net101_26w_4s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth', + }, + 'timm-res2next50': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth', + } +} + +pretrained_settings = {} +for model_name, sources in res2net_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + + +timm_res2net_encoders = { + 'timm-res2net50_26w_4s': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 26, + 'block_args': {'scale': 4} + }, + }, + 'timm-res2net101_26w_4s': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 23, 3], + 'base_width': 26, + 'block_args': {'scale': 4} + }, + }, + 'timm-res2net50_26w_6s': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 26, + 'block_args': {'scale': 6} + }, + }, + 'timm-res2net50_26w_8s': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 26, + 'block_args': {'scale': 8} + }, + }, + 'timm-res2net50_48w_2s': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 48, + 'block_args': {'scale': 2} + }, + }, + 'timm-res2net50_14w_8s': { + 'encoder': Res2NetEncoder, + "pretrained_settings": 
pretrained_settings["timm-res2net50_14w_8s"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 14, + 'block_args': {'scale': 8} + }, + }, + 'timm-res2next50': { + 'encoder': Res2NetEncoder, + "pretrained_settings": pretrained_settings["timm-res2next50"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': Bottle2neck, + 'layers': [3, 4, 6, 3], + 'base_width': 4, + 'cardinality': 8, + 'block_args': {'scale': 4} + }, + } +} diff --git a/segmentation_models_pytorch/encoders/timm_resnest.py b/segmentation_models_pytorch/encoders/timm_resnest.py new file mode 100644 index 00000000..3289f786 --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_resnest.py @@ -0,0 +1,205 @@ +from ._base import EncoderMixin +from timm.models.resnet import ResNet +from timm.models.resnest import ResNestBottleneck +import torch.nn as nn + + +class ResNestEncoder(ResNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.fc + del self.global_pool + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.act1), + nn.Sequential(self.maxpool, self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("fc.bias") + state_dict.pop("fc.weight") + super().load_state_dict(state_dict, **kwargs) + + +resnest_weights = { + 'timm-resnest14d': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth' + }, + 'timm-resnest26d': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth' + }, + 'timm-resnest50d': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth', + }, + 'timm-resnest101e': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth', + }, + 'timm-resnest200e': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth', + }, + 'timm-resnest269e': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth', + }, + 'timm-resnest50d_4s2x40d': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth', + }, + 'timm-resnest50d_1s4x24d': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth', + } +} + +pretrained_settings = {} +for model_name, sources in resnest_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + + +timm_resnest_encoders = { + 'timm-resnest14d': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest14d"], + 'params': { + 
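+            # channels of the feature maps returned by forward(): input,
+            # stem, then layer1..layer4 (matches the get_stages() order above)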
'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [1, 1, 1, 1], + 'stem_type': 'deep', + 'stem_width': 32, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + } + }, + 'timm-resnest26d': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest26d"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [2, 2, 2, 2], + 'stem_type': 'deep', + 'stem_width': 32, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + } + }, + 'timm-resnest50d': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest50d"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 4, 6, 3], + 'stem_type': 'deep', + 'stem_width': 32, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + } + }, + 'timm-resnest101e': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest101e"], + 'params': { + 'out_channels': (3, 128, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 4, 23, 3], + 'stem_type': 'deep', + 'stem_width': 64, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + } + }, + 'timm-resnest200e': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest200e"], + 'params': { + 'out_channels': (3, 128, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 24, 36, 3], + 'stem_type': 'deep', + 'stem_width': 64, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + } + }, + 'timm-resnest269e': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest269e"], + 'params': { + 'out_channels': (3, 128, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 30, 48, 8], + 'stem_type': 'deep', + 'stem_width': 64, + 'avg_down': True, + 'base_width': 64, + 'cardinality': 1, + 'block_args': {'radix': 2, 'avd': True, 'avd_first': False} + }, + }, + 'timm-resnest50d_4s2x40d': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest50d_4s2x40d"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 4, 6, 3], + 'stem_type': 'deep', + 'stem_width': 32, + 'avg_down': True, + 'base_width': 40, + 'cardinality': 2, + 'block_args': {'radix': 4, 'avd': True, 'avd_first': True} + } + }, + 'timm-resnest50d_1s4x24d': { + 'encoder': ResNestEncoder, + "pretrained_settings": pretrained_settings["timm-resnest50d_1s4x24d"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': ResNestBottleneck, + 'layers': [3, 4, 6, 3], + 'stem_type': 'deep', + 'stem_width': 32, + 'avg_down': True, + 'base_width': 24, + 'cardinality': 4, + 'block_args': {'radix': 1, 'avd': True, 'avd_first': True} + } + } +} diff --git a/segmentation_models_pytorch/encoders/timm_sknet.py b/segmentation_models_pytorch/encoders/timm_sknet.py new file mode 100644 index 00000000..bfb7572d --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_sknet.py @@ -0,0 +1,103 @@ +from ._base import EncoderMixin +from timm.models.resnet import ResNet +from timm.models.sknet import 
SelectiveKernelBottleneck, SelectiveKernelBasic +import torch.nn as nn + + +class SkNetEncoder(ResNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.fc + del self.global_pool + + def get_stages(self): + return [ + nn.Identity(), + nn.Sequential(self.conv1, self.bn1, self.act1), + nn.Sequential(self.maxpool, self.layer1), + self.layer2, + self.layer3, + self.layer4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("fc.bias") + state_dict.pop("fc.weight") + super().load_state_dict(state_dict, **kwargs) + + +sknet_weights = { + 'timm-skresnet18': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth' + }, + 'timm-skresnet34': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth' + }, + 'timm-skresnext50_32x4d': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth', + } +} + +pretrained_settings = {} +for model_name, sources in sknet_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + +timm_sknet_encoders = { + 'timm-skresnet18': { + 'encoder': SkNetEncoder, + "pretrained_settings": pretrained_settings["timm-skresnet18"], + 'params': { + 'out_channels': (3, 64, 64, 128, 256, 512), + 'block': SelectiveKernelBasic, + 'layers': [2, 2, 2, 2], + 'zero_init_last_bn': False, + 'block_args': {'sk_kwargs': {'min_attn_channels': 16, 'attn_reduction': 8, 'split_input': True}} + } + }, + 'timm-skresnet34': { + 'encoder': SkNetEncoder, + "pretrained_settings": pretrained_settings["timm-skresnet34"], + 'params': { + 'out_channels': (3, 64, 64, 128, 256, 512), + 'block': SelectiveKernelBasic, + 'layers': [3, 4, 6, 3], + 'zero_init_last_bn': False, + 'block_args': {'sk_kwargs': {'min_attn_channels': 16, 'attn_reduction': 8, 'split_input': True}} + } + }, + 'timm-skresnext50_32x4d': { + 'encoder': SkNetEncoder, + "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"], + 'params': { + 'out_channels': (3, 64, 256, 512, 1024, 2048), + 'block': SelectiveKernelBottleneck, + 'layers': [3, 4, 6, 3], + 'zero_init_last_bn': False, + 'cardinality': 32, + 'base_width': 4 + } + } +} From 0a5599575720cb4e443021dc670aa780bf5921b2 Mon Sep 17 00:00:00 2001 From: Sergey Kolchenko Date: Thu, 26 Nov 2020 21:05:05 -0600 Subject: [PATCH 2/6] fixed tests --- tests/test_models.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index da6c3168..865f42f8 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,7 +6,6 @@ # mock detection module sys.modules["torchvision._C"] = mock.Mock() - import segmentation_models_pytorch as smp IS_TRAVIS = os.environ.get("TRAVIS", False) @@ -65,7 +64,7 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs): encoder_name, encoder_depth=encoder_depth, encoder_weights=None, 
**kwargs ) sample = get_sample(model_class) - + model.eval() if encoder_depth == 5 and model_class != smp.PSPNet: test_shape = True else: @@ -111,6 +110,7 @@ def test_upsample(model_class, upsampling): def test_in_channels(model_class, encoder_name, in_channels): sample = torch.ones([1, in_channels, 64, 64]) model = model_class(DEFAULT_ENCODER, encoder_weights=None, in_channels=in_channels) + model.eval() with torch.no_grad(): model(sample) @@ -120,7 +120,8 @@ def test_in_channels(model_class, encoder_name, in_channels): @pytest.mark.parametrize("encoder_name", ENCODERS) def test_dilation(encoder_name): if (encoder_name in ['inceptionresnetv2', 'xception', 'inceptionv4'] or - encoder_name.startswith('vgg') or encoder_name.startswith('densenet')): + encoder_name.startswith('vgg') or encoder_name.startswith('densenet') or + encoder_name.startswith('timm-res')): return encoder = smp.encoders.get_encoder(encoder_name) From a32f1587b49c8dda4b60d78f253c901e55e8b8a7 Mon Sep 17 00:00:00 2001 From: Sergey Kolchenko Date: Fri, 27 Nov 2020 10:04:57 -0600 Subject: [PATCH 3/6] add raise erorr for dilated encoders in deeplabv3 and PAN --- .../deeplabv3/model.py | 63 +++++++++++++++++-- segmentation_models_pytorch/pan/model.py | 17 +++++ 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/segmentation_models_pytorch/deeplabv3/model.py b/segmentation_models_pytorch/deeplabv3/model.py index 1315859d..695fb0e8 100644 --- a/segmentation_models_pytorch/deeplabv3/model.py +++ b/segmentation_models_pytorch/deeplabv3/model.py @@ -34,7 +34,7 @@ class DeepLabV3(SegmentationModel): .. _DeeplabV3: https://arxiv.org/abs/1706.05587 """ - + def __init__( self, encoder_name: str = "resnet34", @@ -43,6 +43,7 @@ def __init__( decoder_channels: int = 256, in_channels: int = 3, classes: int = 1, + encoder_dilation: bool = True, activation: Optional[str] = None, upsampling: int = 8, aux_params: Optional[dict] = None, @@ -55,10 +56,28 @@ def __init__( depth=encoder_depth, weights=encoder_weights, ) - self.encoder.make_dilated( - stage_list=[4, 5], - dilation_list=[2, 4] - ) + if encoder_dilation: + assert encoder_name not in [ + 'timm-res2net50_26w_4s', + 'timm-res2net50_48w_2s', + 'timm-res2net50_14w_8s', + 'timm-res2net50_26w_6s', + 'timm-res2net50_26w_8s', + 'timm-res2net101_26w_4s', + 'timm-res2next50', + 'timm-resnest14d', + 'timm-resnest26d', + 'timm-resnest50d', + 'timm-resnest101e', + 'timm-resnest200e', + 'timm-resnest269e', + 'timm-resnest50d_4s2x40d', + 'timm-resnest50d_1s4x24d' + ], f'{encoder_name} is not supported for dilation' + self.encoder.make_dilated( + stage_list=[4, 5], + dilation_list=[2, 4] + ) self.decoder = DeepLabV3Decoder( in_channels=self.encoder.out_channels[-1], @@ -136,12 +155,46 @@ def __init__( ) if encoder_output_stride == 8: + assert encoder_name not in [ + 'timm-res2net50_26w_4s', + 'timm-res2net50_48w_2s', + 'timm-res2net50_14w_8s', + 'timm-res2net50_26w_6s', + 'timm-res2net50_26w_8s', + 'timm-res2net101_26w_4s', + 'timm-res2next50', + 'timm-resnest14d', + 'timm-resnest26d', + 'timm-resnest50d', + 'timm-resnest101e', + 'timm-resnest200e', + 'timm-resnest269e', + 'timm-resnest50d_4s2x40d', + 'timm-resnest50d_1s4x24d' + ], f'{encoder_name} is not supported for dilation' self.encoder.make_dilated( stage_list=[4, 5], dilation_list=[2, 4] ) elif encoder_output_stride == 16: + assert encoder_name not in [ + 'timm-res2net50_26w_4s', + 'timm-res2net50_48w_2s', + 'timm-res2net50_14w_8s', + 'timm-res2net50_26w_6s', + 'timm-res2net50_26w_8s', + 'timm-res2net101_26w_4s', + 
'timm-res2next50', + 'timm-resnest14d', + 'timm-resnest26d', + 'timm-resnest50d', + 'timm-resnest101e', + 'timm-resnest200e', + 'timm-resnest269e', + 'timm-resnest50d_4s2x40d', + 'timm-resnest50d_1s4x24d' + ], f'{encoder_name} is not supported for dilation' self.encoder.make_dilated( stage_list=[5], dilation_list=[2] diff --git a/segmentation_models_pytorch/pan/model.py b/segmentation_models_pytorch/pan/model.py index f362a37d..40779eef 100644 --- a/segmentation_models_pytorch/pan/model.py +++ b/segmentation_models_pytorch/pan/model.py @@ -62,6 +62,23 @@ def __init__( ) if encoder_dilation: + assert encoder_name not in [ + 'timm-res2net50_26w_4s', + 'timm-res2net50_48w_2s', + 'timm-res2net50_14w_8s', + 'timm-res2net50_26w_6s', + 'timm-res2net50_26w_8s', + 'timm-res2net101_26w_4s', + 'timm-res2next50', + 'timm-resnest14d', + 'timm-resnest26d', + 'timm-resnest50d', + 'timm-resnest101e', + 'timm-resnest200e', + 'timm-resnest269e', + 'timm-resnest50d_4s2x40d', + 'timm-resnest50d_1s4x24d' + ], f'{encoder_name} is not supported for dilation' self.encoder.make_dilated( stage_list=[5], dilation_list=[2] From 41a4ead0e041961690825a25dcfc7d6c7c8a7ca9 Mon Sep 17 00:00:00 2001 From: Sergey Kolchenko Date: Tue, 1 Dec 2020 23:14:18 -0600 Subject: [PATCH 4/6] raise correct error for encoders without dilation --- .../deeplabv3/model.py | 61 ++----------------- .../encoders/timm_res2net.py | 3 + .../encoders/timm_resnest.py | 3 + 3 files changed, 10 insertions(+), 57 deletions(-) diff --git a/segmentation_models_pytorch/deeplabv3/model.py b/segmentation_models_pytorch/deeplabv3/model.py index 695fb0e8..33cf0718 100644 --- a/segmentation_models_pytorch/deeplabv3/model.py +++ b/segmentation_models_pytorch/deeplabv3/model.py @@ -43,7 +43,6 @@ def __init__( decoder_channels: int = 256, in_channels: int = 3, classes: int = 1, - encoder_dilation: bool = True, activation: Optional[str] = None, upsampling: int = 8, aux_params: Optional[dict] = None, @@ -56,28 +55,10 @@ def __init__( depth=encoder_depth, weights=encoder_weights, ) - if encoder_dilation: - assert encoder_name not in [ - 'timm-res2net50_26w_4s', - 'timm-res2net50_48w_2s', - 'timm-res2net50_14w_8s', - 'timm-res2net50_26w_6s', - 'timm-res2net50_26w_8s', - 'timm-res2net101_26w_4s', - 'timm-res2next50', - 'timm-resnest14d', - 'timm-resnest26d', - 'timm-resnest50d', - 'timm-resnest101e', - 'timm-resnest200e', - 'timm-resnest269e', - 'timm-resnest50d_4s2x40d', - 'timm-resnest50d_1s4x24d' - ], f'{encoder_name} is not supported for dilation' - self.encoder.make_dilated( - stage_list=[4, 5], - dilation_list=[2, 4] - ) + self.encoder.make_dilated( + stage_list=[4, 5], + dilation_list=[2, 4] + ) self.decoder = DeepLabV3Decoder( in_channels=self.encoder.out_channels[-1], @@ -155,46 +136,12 @@ def __init__( ) if encoder_output_stride == 8: - assert encoder_name not in [ - 'timm-res2net50_26w_4s', - 'timm-res2net50_48w_2s', - 'timm-res2net50_14w_8s', - 'timm-res2net50_26w_6s', - 'timm-res2net50_26w_8s', - 'timm-res2net101_26w_4s', - 'timm-res2next50', - 'timm-resnest14d', - 'timm-resnest26d', - 'timm-resnest50d', - 'timm-resnest101e', - 'timm-resnest200e', - 'timm-resnest269e', - 'timm-resnest50d_4s2x40d', - 'timm-resnest50d_1s4x24d' - ], f'{encoder_name} is not supported for dilation' self.encoder.make_dilated( stage_list=[4, 5], dilation_list=[2, 4] ) elif encoder_output_stride == 16: - assert encoder_name not in [ - 'timm-res2net50_26w_4s', - 'timm-res2net50_48w_2s', - 'timm-res2net50_14w_8s', - 'timm-res2net50_26w_6s', - 'timm-res2net50_26w_8s', - 
'timm-res2net101_26w_4s', - 'timm-res2next50', - 'timm-resnest14d', - 'timm-resnest26d', - 'timm-resnest50d', - 'timm-resnest101e', - 'timm-resnest200e', - 'timm-resnest269e', - 'timm-resnest50d_4s2x40d', - 'timm-resnest50d_1s4x24d' - ], f'{encoder_name} is not supported for dilation' self.encoder.make_dilated( stage_list=[5], dilation_list=[2] diff --git a/segmentation_models_pytorch/encoders/timm_res2net.py b/segmentation_models_pytorch/encoders/timm_res2net.py index a2f32785..d3766b9d 100644 --- a/segmentation_models_pytorch/encoders/timm_res2net.py +++ b/segmentation_models_pytorch/encoders/timm_res2net.py @@ -24,6 +24,9 @@ def get_stages(self): self.layer4, ] + def make_dilated(self, stage_list, dilation_list): + raise ValueError("Res2Net encoders do not support dilated mode") + def forward(self, x): stages = self.get_stages() diff --git a/segmentation_models_pytorch/encoders/timm_resnest.py b/segmentation_models_pytorch/encoders/timm_resnest.py index 3289f786..77c558c9 100644 --- a/segmentation_models_pytorch/encoders/timm_resnest.py +++ b/segmentation_models_pytorch/encoders/timm_resnest.py @@ -24,6 +24,9 @@ def get_stages(self): self.layer4, ] + def make_dilated(self, stage_list, dilation_list): + raise ValueError("ResNest encoders do not support dilated mode") + def forward(self, x): stages = self.get_stages() From aa5b3922a03d2e471c9e3b6863b260a3fef0ee4d Mon Sep 17 00:00:00 2001 From: Sergey Kolchenko Date: Fri, 4 Dec 2020 09:13:12 -0600 Subject: [PATCH 5/6] update README --- README.md | 42 ++++++++++++++++++++++++ segmentation_models_pytorch/pan/model.py | 17 ---------- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 88c572b2..7ace2c45 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,48 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet') |timm-efficientnet-b7 |imagenet
advprop<br>noisy-student|63M |
 |timm-efficientnet-b8 |imagenet<br>advprop |84M |
 |timm-efficientnet-l2 |noisy-student |474M |
+|timm-resnest14d |imagenet |8M |
+|timm-resnest26d |imagenet |15M |
+|timm-resnest50d |imagenet |25M |
+|timm-resnest101e |imagenet |46M |
+|timm-resnest200e |imagenet |68M |
+|timm-resnest269e |imagenet |108M |
+|timm-resnest50d_4s2x40d |imagenet |28M |
+|timm-resnest50d_1s4x24d |imagenet |23M |
+|timm-res2net50_26w_4s |imagenet |23M |
+|timm-res2net101_26w_4s |imagenet |43M |
+|timm-res2net50_26w_6s |imagenet |35M |
+|timm-res2net50_26w_8s |imagenet |46M |
+|timm-res2net50_48w_2s |imagenet |23M |
+|timm-res2net50_14w_8s |imagenet |23M |
+|timm-res2next50 |imagenet |22M |
+|timm-regnetx_002 |imagenet |2M |
+|timm-regnetx_004 |imagenet |4M |
+|timm-regnetx_006 |imagenet |5M |
+|timm-regnetx_008 |imagenet |6M |
+|timm-regnetx_016 |imagenet |8M |
+|timm-regnetx_032 |imagenet |14M |
+|timm-regnetx_040 |imagenet |20M |
+|timm-regnetx_064 |imagenet |24M |
+|timm-regnetx_080 |imagenet |37M |
+|timm-regnetx_120 |imagenet |43M |
+|timm-regnetx_160 |imagenet |52M |
+|timm-regnetx_320 |imagenet |105M |
+|timm-regnety_002 |imagenet |2M |
+|timm-regnety_004 |imagenet |3M |
+|timm-regnety_006 |imagenet |5M |
+|timm-regnety_008 |imagenet |5M |
+|timm-regnety_016 |imagenet |10M |
+|timm-regnety_032 |imagenet |17M |
+|timm-regnety_040 |imagenet |19M |
+|timm-regnety_064 |imagenet |29M |
+|timm-regnety_080 |imagenet |37M |
+|timm-regnety_120 |imagenet |49M |
+|timm-regnety_160 |imagenet |80M |
+|timm-regnety_320 |imagenet |141M |
+|timm-skresnet18 |imagenet |11M |
+|timm-skresnet34 |imagenet |21M |
+|timm-skresnext50_32x4d |imagenet |25M |
 
 \* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).

diff --git a/segmentation_models_pytorch/pan/model.py b/segmentation_models_pytorch/pan/model.py
index 40779eef..f362a37d 100644
--- a/segmentation_models_pytorch/pan/model.py
+++ b/segmentation_models_pytorch/pan/model.py
@@ -62,23 +62,6 @@ def __init__(
         )
 
         if encoder_dilation:
-            assert encoder_name not in [
-                'timm-res2net50_26w_4s',
-                'timm-res2net50_48w_2s',
-                'timm-res2net50_14w_8s',
-                'timm-res2net50_26w_6s',
-                'timm-res2net50_26w_8s',
-                'timm-res2net101_26w_4s',
-                'timm-res2next50',
-                'timm-resnest14d',
-                'timm-resnest26d',
-                'timm-resnest50d',
-                'timm-resnest101e',
-                'timm-resnest200e',
-                'timm-resnest269e',
-                'timm-resnest50d_4s2x40d',
-                'timm-resnest50d_1s4x24d'
-            ], f'{encoder_name} is not supported for dilation'
             self.encoder.make_dilated(
                 stage_list=[5],
                 dilation_list=[2]

From 51dad0f8c22a67cfe5a7935bfc15b8ae0641e590 Mon Sep 17 00:00:00 2001
From: Sergey Kolchenko
Date: Fri, 4 Dec 2020 09:14:59 -0600
Subject: [PATCH 6/6] fix README

---
 README.md | 123 ++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 82 insertions(+), 41 deletions(-)

diff --git a/README.md b/README.md
index 7ace2c45..62b78aa0 100644
--- a/README.md
+++ b/README.md
@@ -11,11 +11,11 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
 The main features of this library are:
 
 - High level API (just two lines to create neural network)
- - 7 models architectures for binary and multi class segmentation (including legendary Unet)
+ - 8 model architectures for binary and multi-class segmentation (including legendary Unet)
 - 57 available encoders for each architecture
 - All encoders have pre-trained weights for faster and better convergence
 
-### Table of content
+### 📋 Table of contents
 1. [Quick start](#start)
 2. [Examples](#examples)
 3. [Models](#models)
 4. [Models API](#api)
 5. [Installation](#installation)
 8. [Citing](#citing)
 9. [License](#license)
 
-### Quick start
-Since the library is built on the PyTorch framework, created segmentation model is just a PyTorch nn.Module, which can be created as easy as:
-```python
-import segmentation_models_pytorch as smp
-
-model = smp.Unet()
-```
-Depending on the task, you can change the network architecture by choosing backbones with fewer or more parameters and use pretrainded weights to initialize it:
-```python
-model = smp.Unet('resnet34', encoder_weights='imagenet')
-```
-Change number of output classes in the model:
-```python
-model = smp.Unet('resnet34', classes=3, activation='softmax')
-```
-All models have pretrained encoders, so you have to prepare your data the same way as during weights pretraining:
+### ⏳ Quick start
+
+#### 1. Create your first Segmentation model with SMP
+
+Segmentation model is just a PyTorch nn.Module, which can be created as easy as:
+
+```python
+import segmentation_models_pytorch as smp
+
+model = smp.Unet(
+    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
+    encoder_weights="imagenet",     # use `imagenet` pretrained weights for encoder initialization
+    in_channels=1,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
+    classes=3,                      # model output channels (number of classes in your dataset)
+)
+```
+ - see [table](#architectures) with available model architectures
+ - see [table](#encoders) with available encoders and their corresponding weights
+
+#### 2. Configure data preprocessing
+
+All encoders have pretrained weights. Preparing your data the same way as during weights pretraining may give you better results (higher metric score and faster convergence). But it is relevant only for 1-2-3-channel images and **not necessary** in case you train the whole model, not only the decoder.
+
 ```python
 from segmentation_models_pytorch.encoders import get_preprocessing_fn
 
 preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
 ```
+
+Congratulations! You are done! Now you can train your model with your favorite framework!
 
-### Examples
+### 💡 Examples
 - Training model for cars segmentation on CamVid dataset [here](https://github.com/qubvel/segmentation_models.pytorch/blob/master/examples/cars%20segmentation%20(camvid).ipynb).
- - Training SMP model with [Catalyst](https://github.com/catalyst-team/catalyst) (high-level framework for PyTorch), [Ttach](https://github.com/qubvel/ttach) (TTA library for PyTorch) and [Albumentations](https://github.com/albu/albumentations) (fast image augmentation library) - [here](https://github.com/catalyst-team/catalyst/blob/master/examples/notebooks/segmentation-tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/segmentation-tutorial.ipynb) + - Training SMP model with [Catalyst](https://github.com/catalyst-team/catalyst) (high-level framework for PyTorch), [TTAch](https://github.com/qubvel/ttach) (TTA library for PyTorch) and [Albumentations](https://github.com/albu/albumentations) (fast image augmentation library) - [here](https://github.com/catalyst-team/catalyst/blob/master/examples/notebooks/segmentation-tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/segmentation-tutorial.ipynb) -### Models +### 📦 Models #### Architectures - [Unet](https://arxiv.org/abs/1505.04597) and [Unet++](https://arxiv.org/pdf/1807.10165.pdf) @@ -72,17 +78,20 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet') #### Encoders +
+<details>
+<summary>Table with ALL available encoders (click to expand)</summary>
+
 |Encoder                         |Weights                         |Params, M       |
 |--------------------------------|:------------------------------:|:------------------------------:|
-|resnet18 |imagenet<br>ssl*<br>swsl* |11M |
+|resnet18 |imagenet / ssl / swsl |11M |
 |resnet34 |imagenet |21M |
-|resnet50 |imagenet<br>ssl*<br>swsl* |23M |
+|resnet50 |imagenet / ssl / swsl |23M |
 |resnet101 |imagenet |42M |
 |resnet152 |imagenet |58M |
-|resnext50_32x4d |imagenet<br>ssl*<br>swsl* |22M |
-|resnext101_32x4d |ssl<br>swsl |42M |
-|resnext101_32x8d |imagenet<br>instagram<br>ssl*<br>swsl*|86M |
-|resnext101_32x16d |instagram<br>ssl*<br>swsl* |191M |
+|resnext50_32x4d |imagenet / ssl / swsl |22M |
+|resnext101_32x4d |ssl / swsl |42M |
+|resnext101_32x8d |imagenet / instagram / ssl / swsl|86M |
+|resnext101_32x16d |instagram / ssl / swsl |191M |
 |resnext101_32x32d |instagram |466M |
 |resnext101_32x48d |instagram |826M |
 |dpn68 |imagenet |11M |
 |dpn68b |imagenet+5k |11M |
 |dpn92 |imagenet+5k |34M |
 |dpn98 |imagenet |58M |
 |dpn107 |imagenet+5k |84M |
 |dpn131 |imagenet |76M |
 |vgg11 |imagenet |9M |
 |vgg11_bn |imagenet |9M |
 |vgg13 |imagenet |9M |
 |vgg13_bn |imagenet |9M |
 |vgg16 |imagenet |14M |
 |vgg16_bn |imagenet |14M |
 |vgg19 |imagenet |20M |
 |vgg19_bn |imagenet |20M |
 |senet154 |imagenet |113M |
 |se_resnet50 |imagenet |26M |
 |se_resnet101 |imagenet |47M |
 |se_resnet152 |imagenet |64M |
 |se_resnext50_32x4d |imagenet |25M |
 |se_resnext101_32x4d |imagenet |46M |
 |densenet121 |imagenet |6M |
 |densenet169 |imagenet |12M |
 |densenet201 |imagenet |18M |
 |densenet161 |imagenet |26M |
-|inceptionresnetv2 |imagenet<br>imagenet+background |54M |
-|inceptionv4 |imagenet<br>imagenet+background |41M |
+|inceptionresnetv2 |imagenet / imagenet+background |54M |
+|inceptionv4 |imagenet / imagenet+background |41M |
 |efficientnet-b0 |imagenet |4M |
 |efficientnet-b1 |imagenet |6M |
 |efficientnet-b2 |imagenet |7M |
 |efficientnet-b3 |imagenet |10M |
 |efficientnet-b4 |imagenet |17M |
 |efficientnet-b5 |imagenet |28M |
 |efficientnet-b6 |imagenet |40M |
 |efficientnet-b7 |imagenet |63M |
 |mobilenet_v2 |imagenet |2M |
 |xception |imagenet |22M |
-|timm-efficientnet-b0 |imagenet<br>advprop<br>noisy-student|4M |
-|timm-efficientnet-b1 |imagenet<br>advprop<br>noisy-student|6M |
-|timm-efficientnet-b2 |imagenet<br>advprop<br>noisy-student|7M |
-|timm-efficientnet-b3 |imagenet<br>advprop<br>noisy-student|10M |
-|timm-efficientnet-b4 |imagenet<br>advprop<br>noisy-student|17M |
-|timm-efficientnet-b5 |imagenet<br>advprop<br>noisy-student|28M |
-|timm-efficientnet-b6 |imagenet<br>advprop<br>noisy-student|40M |
-|timm-efficientnet-b7 |imagenet<br>advprop<br>noisy-student|63M |
-|timm-efficientnet-b8 |imagenet<br>advprop |84M |
+|timm-efficientnet-b0 |imagenet / advprop / noisy-student|4M |
+|timm-efficientnet-b1 |imagenet / advprop / noisy-student|6M |
+|timm-efficientnet-b2 |imagenet / advprop / noisy-student|7M |
+|timm-efficientnet-b3 |imagenet / advprop / noisy-student|10M |
+|timm-efficientnet-b4 |imagenet / advprop / noisy-student|17M |
+|timm-efficientnet-b5 |imagenet / advprop / noisy-student|28M |
+|timm-efficientnet-b6 |imagenet / advprop / noisy-student|40M |
+|timm-efficientnet-b7 |imagenet / advprop / noisy-student|63M |
+|timm-efficientnet-b8 |imagenet / advprop |84M |
 |timm-efficientnet-l2 |noisy-student |474M |
 |timm-resnest14d |imagenet |8M |
 |timm-resnest26d |imagenet |15M |
 |timm-resnest50d |imagenet |25M |
 |timm-resnest101e |imagenet |46M |
 |timm-resnest200e |imagenet |68M |
 |timm-resnest269e |imagenet |108M |
 |timm-resnest50d_4s2x40d |imagenet |28M |
 |timm-resnest50d_1s4x24d |imagenet |23M |
 |timm-res2net50_26w_4s |imagenet |23M |
 |timm-res2net101_26w_4s |imagenet |43M |
 |timm-res2net50_26w_6s |imagenet |35M |
 |timm-res2net50_26w_8s |imagenet |46M |
 |timm-res2net50_48w_2s |imagenet |23M |
 |timm-res2net50_14w_8s |imagenet |23M |
 |timm-res2next50 |imagenet |22M |
 |timm-regnetx_002 |imagenet |2M |
 |timm-regnetx_004 |imagenet |4M |
 |timm-regnetx_006 |imagenet |5M |
 |timm-regnetx_008 |imagenet |6M |
 |timm-regnetx_016 |imagenet |8M |
 |timm-regnetx_032 |imagenet |14M |
 |timm-regnetx_040 |imagenet |20M |
 |timm-regnetx_064 |imagenet |24M |
 |timm-regnetx_080 |imagenet |37M |
 |timm-regnetx_120 |imagenet |43M |
 |timm-regnetx_160 |imagenet |52M |
 |timm-regnetx_320 |imagenet |105M |
 |timm-regnety_002 |imagenet |2M |
 |timm-regnety_004 |imagenet |3M |
 |timm-regnety_006 |imagenet |5M |
 |timm-regnety_008 |imagenet |5M |
 |timm-regnety_016 |imagenet |10M |
 |timm-regnety_032 |imagenet |17M |
 |timm-regnety_040 |imagenet |19M |
 |timm-regnety_064 |imagenet |29M |
 |timm-regnety_080 |imagenet |37M |
 |timm-regnety_120 |imagenet |49M |
 |timm-regnety_160 |imagenet |80M |
 |timm-regnety_320 |imagenet |141M |
 |timm-skresnet18 |imagenet |11M |
 |timm-skresnet34 |imagenet |21M |
 |timm-skresnext50_32x4d |imagenet |25M |
 
 \* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).
 
+</details>
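As a quick illustration of what these new registrations provide, here is a minimal sketch (not part of the diff; it assumes this patch series is applied and `timm` is installed, and the encoder name is just one example from the table above):

```python
import torch
from segmentation_models_pytorch.encoders import get_encoder

# Any name registered by this series works here,
# e.g. "timm-regnety_032" or "timm-skresnet18".
encoder = get_encoder("timm-res2net50_26w_4s", in_channels=3, depth=5, weights=None)

features = encoder(torch.rand(1, 3, 224, 224))

# forward() collects one feature map per stage (depth + 1 in total);
# channel counts follow the encoder's 'out_channels' tuple:
print([f.shape[1] for f in features])  # [3, 64, 256, 512, 1024, 2048]
```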
+ +Just commonly used encoders + +|Encoder |Weights |Params, M | +|--------------------------------|:------------------------------:|:------------------------------:| +|resnet18 |imagenet / ssl / swsl |11M | +|resnet34 |imagenet |21M | +|resnet50 |imagenet / ssl / swsl |23M | +|resnet101 |imagenet |42M | +|resnext50_32x4d |imagenet / ssl / swsl |22M | +|resnext101_32x4d |ssl / swsl |42M | +|resnext101_32x8d |imagenet / instagram / ssl / swsl|86M | +|senet154 |imagenet |113M | +|se_resnext50_32x4d |imagenet |25M | +|se_resnext101_32x4d |imagenet |46M | +|densenet121 |imagenet |6M | +|densenet169 |imagenet |12M | +|densenet201 |imagenet |18M | +|inceptionresnetv2 |imagenet / imagenet+background |54M | +|inceptionv4 |imagenet / imagenet+background |41M | +|mobilenet_v2 |imagenet |2M | +|timm-efficientnet-b0 |imagenet / advprop / noisy-student|4M | +|timm-efficientnet-b1 |imagenet / advprop / noisy-student|6M | +|timm-efficientnet-b2 |imagenet / advprop / noisy-student|7M | +|timm-efficientnet-b3 |imagenet / advprop / noisy-student|10M | +|timm-efficientnet-b4 |imagenet / advprop / noisy-student|17M | +|timm-efficientnet-b5 |imagenet / advprop / noisy-student|28M | +|timm-efficientnet-b6 |imagenet / advprop / noisy-student|40M | +|timm-efficientnet-b7 |imagenet / advprop / noisy-student|63M | + + +### 🔁 Models API - `model.encoder` - pretrained backbone to extract features of different spatial resolution - `model.decoder` - depends on models architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`) @@ -218,7 +259,7 @@ model = smp.Unet('resnet34', encoder_depth=4) ``` -### Installation +### 🛠 Installation PyPI version: ```bash $ pip install segmentation-models-pytorch @@ -228,12 +269,12 @@ Latest version from source: $ pip install git+https://github.com/qubvel/segmentation_models.pytorch ```` -### Competitions won with the library +### 🏆 Competitions won with the library `Segmentation Models` package is widely used in the image segmentation competitions. [Here](https://github.com/qubvel/segmentation_models.pytorch/blob/master/HALLOFFAME.md) you can find competitions, names of the winners and links to their solutions. -### Contributing +### 🤝 Contributing ##### Run test ```bash @@ -244,7 +285,7 @@ $ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev $ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py ``` -### Citing +### 📝 Citing ``` @misc{Yakubovskiy:2019, Author = {Pavel Yakubovskiy}, @@ -256,5 +297,5 @@ $ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev } ``` -### License +### 🛡️ License Project is distributed under [MIT License](https://github.com/qubvel/segmentation_models.pytorch/blob/master/LICENSE)
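One behavioral note that falls out of patches 3 and 4 of this series: the Res2Net and ResNeSt encoders override `make_dilated` to raise a `ValueError`, so decoders that dilate their encoder (DeepLabV3/DeepLabV3+, and PAN with its default `encoder_dilation=True`) reject these backbones at construction time, while non-dilating decoders such as Unet accept them. A minimal sketch of the failure mode, assuming the final state of this series and `timm` installed:

```python
import segmentation_models_pytorch as smp

# Unet never dilates its encoder, so the new backbones drop in directly:
model = smp.Unet("timm-resnest50d", encoder_weights=None, classes=2)

# PAN dilates the last encoder stage by default (encoder_dilation=True),
# which the ResNeSt/Res2Net encoders refuse:
try:
    smp.PAN("timm-resnest50d", encoder_weights=None)
except ValueError as err:
    print(err)  # "ResNest encoders do not support dilated mode"
```

This is also why `test_dilation` in `tests/test_models.py` now skips encoder names starting with `timm-res`.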