4 changes: 0 additions & 4 deletions mypy.ini
@@ -21,10 +21,6 @@ ignore_errors=True

ignore_errors = True

[mypy-torchvision.models.detection.backbone_utils]

ignore_errors = True

[mypy-torchvision.models.detection.transform]

ignore_errors = True
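With this override removed, backbone_utils must now type-check on its own. A quick local sanity check (a sketch, assuming a repo checkout with mypy installed) using mypy's Python API:

from mypy import api

# Run mypy against the newly annotated module, reusing the repo's own config.
stdout, stderr, exit_status = api.run(
    ["--config-file", "mypy.ini", "torchvision/models/detection/backbone_utils.py"]
)
print(stdout or stderr)
print("exit status:", exit_status)  # 0 means no type errors were reported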
62 changes: 38 additions & 24 deletions torchvision/models/detection/backbone_utils.py
@@ -1,7 +1,7 @@
import warnings
from typing import List, Optional
from typing import Callable, Dict, Optional, List

from torch import nn
from torch import nn, Tensor
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock

@@ -29,7 +29,14 @@ class BackboneWithFPN(nn.Module):
out_channels (int): the number of channels in the FPN
"""

def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
def __init__(
self,
backbone: nn.Module,
return_layers: Dict[str, str],
in_channels_list: List[int],
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> None:
super(BackboneWithFPN, self).__init__()

if extra_blocks is None:
@@ -43,20 +50,20 @@ def __init__(self, backbone, return_layers, in_channels_list, out_channels, extr
)
self.out_channels = out_channels

def forward(self, x):
def forward(self, x: Tensor) -> Dict[str, Tensor]:
x = self.body(x)
x = self.fpn(x)
return x


def resnet_fpn_backbone(
backbone_name,
pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3,
returned_layers=None,
extra_blocks=None,
):
backbone_name: str,
pretrained: bool,
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers: int = 3,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
"""
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.

@@ -80,7 +87,7 @@ def resnet_fpn_backbone(
backbone_name (string): resnet architecture. Possible values are 'ResNet', 'resnet18', 'resnet34', 'resnet50',
'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
norm_layer (torchvision.ops): it is recommended to use the default value. For details visit:
norm_layer (callable): it is recommended to use the default value. For details visit:
(https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
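A minimal usage sketch of the signature above (pretrained=False keeps it offline; the keys reflect the default returned_layers plus the LastLevelMaxPool extra block):

import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone("resnet50", pretrained=False, trainable_layers=3)
x = torch.rand(1, 3, 64, 64)            # dummy image batch
features = backbone(x)                  # Dict[str, Tensor] from BackboneWithFPN.forward
print([(k, v.shape) for k, v in features.items()])
# keys '0'-'3' plus 'pool', each map with backbone.out_channels (256) channels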
@@ -101,7 +108,8 @@ def _resnet_backbone_config(
trainable_layers: int,
returned_layers: Optional[List[int]],
extra_blocks: Optional[ExtraFPNBlock],
):
) -> BackboneWithFPN:

# select layers that won't be frozen
assert 0 <= trainable_layers <= 5
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
@@ -125,8 +133,13 @@ def _resnet_backbone_config(
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)


def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
# dont freeze any layers if pretrained model or backbone is not used
def _validate_trainable_layers(
pretrained: bool,
trainable_backbone_layers: Optional[int],
max_value: int,
default_value: int,
) -> int:
# don't freeze any layers if pretrained model or backbone is not used
if not pretrained:
if trainable_backbone_layers is not None:
warnings.warn(
@@ -144,14 +157,15 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,


def mobilenet_backbone(
backbone_name,
pretrained,
fpn,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2,
returned_layers=None,
extra_blocks=None,
):
backbone_name: str,
pretrained: bool,
fpn: bool,
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers: int = 2,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:

backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -185,5 +199,5 @@ def mobilenet_backbone(
# depthwise linear combination of channels to reduce their size
nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
)
m.out_channels = out_channels
m.out_channels = out_channels # type: ignore[assignment]
return m
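A companion sketch for the MobileNet path: with fpn=True the builder returns a BackboneWithFPN, otherwise the plain nn.Sequential assembled above (both expose out_channels, hence the type: ignore on the attribute assignment):

import torch
from torchvision.models.detection.backbone_utils import mobilenet_backbone

backbone = mobilenet_backbone("mobilenet_v3_large", pretrained=False, fpn=True, trainable_layers=2)
features = backbone(torch.rand(1, 3, 224, 224))    # Dict[str, Tensor] when fpn=True
print(backbone.out_channels, list(features.keys()))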
22 changes: 14 additions & 8 deletions torchvision/prototype/models/detection/backbone_utils.py
@@ -1,14 +1,20 @@
from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config
from typing import Callable, Optional, List

from torch import nn

from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config, BackboneWithFPN, ExtraFPNBlock
from .. import resnet
from .._api import Weights


def resnet_fpn_backbone(
backbone_name,
weights,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3,
returned_layers=None,
extra_blocks=None,
):
backbone_name: str,
weights: Optional[Weights],
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers: int = 3,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:

backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
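A hedged sketch of the prototype variant, which mirrors the stable builder but takes an Optional[Weights] entry instead of a pretrained flag (weights=None keeps the example self-contained and avoids a download; the prototype namespace may require a recent nightly build):

import torch
from torchvision.prototype.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone("resnet50", weights=None)
features = backbone(torch.rand(1, 3, 64, 64))
print(type(backbone).__name__, list(features.keys()))  # BackboneWithFPN, '0'-'3' + 'pool'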