Skip to content

Commit 98df9a8

Browse files
committed
Adding weights_backbone support in detection and segmentation
1 parent fc42bf0 commit 98df9a8

File tree

4 files changed

+21
-17
lines changed

4 files changed

+21
-17
lines changed

references/detection/README.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,65 +24,65 @@ Except otherwise noted, all models have been trained on 8x V100 GPUs.
2424
```
2525
torchrun --nproc_per_node=8 train.py\
2626
--dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
27-
--lr-steps 16 22 --aspect-ratio-group-factor 3
27+
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
2828
```
2929

3030
### Faster R-CNN MobileNetV3-Large FPN
3131
```
3232
torchrun --nproc_per_node=8 train.py\
3333
--dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
34-
--lr-steps 16 22 --aspect-ratio-group-factor 3
34+
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
3535
```
3636

3737
### Faster R-CNN MobileNetV3-Large 320 FPN
3838
```
3939
torchrun --nproc_per_node=8 train.py\
4040
--dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\
41-
--lr-steps 16 22 --aspect-ratio-group-factor 3
41+
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
4242
```
4343

4444
### FCOS ResNet-50 FPN
4545
```
4646
torchrun --nproc_per_node=8 train.py\
4747
--dataset coco --model fcos_resnet50_fpn --epochs 26\
48-
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp
48+
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp --weights-backbone ResNet50_Weights.IMAGENET1K_V1
4949
```
5050

5151
### RetinaNet
5252
```
5353
torchrun --nproc_per_node=8 train.py\
5454
--dataset coco --model retinanet_resnet50_fpn --epochs 26\
55-
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
55+
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
5656
```
5757

5858
### SSD300 VGG16
5959
```
6060
torchrun --nproc_per_node=8 train.py\
6161
--dataset coco --model ssd300_vgg16 --epochs 120\
6262
--lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
63-
--weight-decay 0.0005 --data-augmentation ssd
63+
--weight-decay 0.0005 --data-augmentation ssd --weights-backbone VGG16_Weights.IMAGENET1K_FEATURES
6464
```
6565

6666
### SSDlite320 MobileNetV3-Large
6767
```
6868
torchrun --nproc_per_node=8 train.py\
6969
--dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
7070
--aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
71-
--weight-decay 0.00004 --data-augmentation ssdlite
71+
--weight-decay 0.00004 --data-augmentation ssdlite --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
7272
```
7373

7474

7575
### Mask R-CNN
7676
```
7777
torchrun --nproc_per_node=8 train.py\
7878
--dataset coco --model maskrcnn_resnet50_fpn --epochs 26\
79-
--lr-steps 16 22 --aspect-ratio-group-factor 3
79+
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
8080
```
8181

8282

8383
### Keypoint R-CNN
8484
```
8585
torchrun --nproc_per_node=8 train.py\
8686
--dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\
87-
--lr-steps 36 43 --aspect-ratio-group-factor 3
87+
--lr-steps 36 43 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
8888
```

references/detection/train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def get_args_parser(add_help=True):
129129
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
130130
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
131131
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
132+
parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")
132133

133134
# Mixed precision training parameters
134135
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
@@ -178,7 +179,9 @@ def main(args):
178179
if "rcnn" in args.model:
179180
if args.rpn_score_thresh is not None:
180181
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
181-
model = torchvision.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
182+
model = torchvision.models.detection.__dict__[args.model](
183+
weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs
184+
)
182185
model.to(device)
183186
if args.distributed and args.sync_bn:
184187
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

references/segmentation/README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,30 @@ You must modify the following flags:
1414

1515
## fcn_resnet50
1616
```
17-
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss
17+
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1
1818
```
1919

2020
## fcn_resnet101
2121
```
22-
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss
22+
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1
2323
```
2424

2525
## deeplabv3_resnet50
2626
```
27-
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss
27+
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1
2828
```
2929

3030
## deeplabv3_resnet101
3131
```
32-
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss
32+
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1
3333
```
3434

3535
## deeplabv3_mobilenet_v3_large
3636
```
37-
torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001
37+
torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
3838
```
3939

4040
## lraspp_mobilenet_v3_large
4141
```
42-
torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001
42+
torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
4343
```

references/segmentation/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def main(args):
124124
)
125125

126126
model = torchvision.models.segmentation.__dict__[args.model](
127-
weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
127+
weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss
128128
)
129129
model.to(device)
130130
if args.distributed:
@@ -258,6 +258,7 @@ def get_args_parser(add_help=True):
258258
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
259259

260260
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
261+
parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")
261262

262263
# Mixed precision training parameters
263264
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

0 commit comments

Comments
 (0)