Skip to content

Commit 300ef76

Browse files
authored
Refactor resnet_fpn_backbone according to docstring conventions (#2482)
1 parent 9bd25d0 commit 300ef76

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

torchvision/models/detection/backbone_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,6 @@ def forward(self, x):
4242

4343

4444
def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3):
45-
backbone = resnet.__dict__[backbone_name](
46-
pretrained=pretrained,
47-
norm_layer=norm_layer)
4845
"""
4946
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
5047
@@ -73,6 +70,10 @@ def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.Frozen
7370
trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
7471
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
7572
"""
73+
backbone = resnet.__dict__[backbone_name](
74+
pretrained=pretrained,
75+
norm_layer=norm_layer)
76+
7677
# select layers that wont be frozen
7778
assert trainable_layers <= 5 and trainable_layers >= 0
7879
layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]

0 commit comments

Comments
 (0)