
Commit d9a6950

Update galleries to use Multi-weight idioms (#6030)
* Update the preprocessing description for RAFT.
* Fix incorrect usage of models.
* Fix the content of viz utils.
* Address review comments.
1 parent 10acc82 commit d9a6950

2 files changed (+26 lines, -60 lines)
gallery/plot_optical_flow.py

Lines changed: 5 additions & 5 deletions
@@ -81,11 +81,11 @@ def plot(imgs, **imshow_kwargs):
plot(img1_batch)

#########################
-# The RAFT model that we will use accepts RGB float images with pixel values in
-# [-1, 1]. The frames we got from :func:`~torchvision.io.read_video` are int
-# images with values in [0, 255], so we will have to pre-process them. We also
-# reduce the image sizes for the example to run faster. Image dimension must be
-# divisible by 8.
+# The RAFT model accepts RGB images. We first get the frames from
+# :func:`~torchvision.io.read_video` and resize them to ensure their
+# dimensions are divisible by 8. Then we use the transforms bundled into the
+# weights in order to preprocess the input and rescale its values to the
+# required ``[-1, 1]`` interval.

from torchvision.models.optical_flow import Raft_Large_Weights
gallery/plot_visualization_utils.py

Lines changed: 21 additions & 55 deletions
@@ -43,8 +43,9 @@ def show(imgs):

dog1_int = read_image(str(Path('assets') / 'dog1.jpg'))
dog2_int = read_image(str(Path('assets') / 'dog2.jpg'))
+dog_list = [dog1_int, dog2_int]

-grid = make_grid([dog1_int, dog2_int, dog1_int, dog2_int])
+grid = make_grid(dog_list)
show(grid)

####################################
@@ -65,28 +66,23 @@ def show(imgs):

#####################################
# Naturally, we can also plot bounding boxes produced by torchvision detection
-# models. Here is demo with a Faster R-CNN model loaded from
+# models. Here is a demo with a Faster R-CNN model loaded from
# :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`
-# model. You can also try using a RetinaNet with
-# :func:`~torchvision.models.detection.retinanet_resnet50_fpn`, an SSDlite with
-# :func:`~torchvision.models.detection.ssdlite320_mobilenet_v3_large` or an SSD with
-# :func:`~torchvision.models.detection.ssd300_vgg16`. For more details
-# on the output of such models, you may refer to :ref:`instance_seg_output`.
+# model. For more details on the output of such models, you may
+# refer to :ref:`instance_seg_output`.

from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights


-batch_int = torch.stack([dog1_int, dog2_int])
-
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

-batch = transforms(batch_int)
+images = [transforms(d) for d in dog_list]

model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
model = model.eval()

-outputs = model(batch)
+outputs = model(images)
print(outputs)

#####################################
@@ -96,7 +92,7 @@ def show(imgs):
score_threshold = .8
dogs_with_boxes = [
    draw_bounding_boxes(dog_int, boxes=output['boxes'][output['scores'] > score_threshold], width=4)
-    for dog_int, output in zip(batch_int, outputs)
+    for dog_int, output in zip(dog_list, outputs)
]
show(dogs_with_boxes)

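End to end, the updated detection snippet now reads roughly as below — a sketch consolidating the two hunks above, with an added `torch.no_grad()` that the gallery itself does not use:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.utils import draw_bounding_boxes

weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

# Detection models take a list of images rather than a stacked batch,
# so each uint8 image is transformed individually.
images = [transforms(d) for d in dog_list]

model = fasterrcnn_resnet50_fpn(weights=weights, progress=False).eval()
with torch.no_grad():  # not in the gallery, but avoids tracking gradients
    outputs = model(images)

score_threshold = .8
dogs_with_boxes = [
    draw_bounding_boxes(dog_int, boxes=output['boxes'][output['scores'] > score_threshold], width=4)
    for dog_int, output in zip(dog_list, outputs)
]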
@@ -114,14 +110,8 @@ def show(imgs):
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We will see how to use it with torchvision's FCN Resnet-50, loaded with
-# :func:`~torchvision.models.segmentation.fcn_resnet50`. You can also try using
-# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`) or
-# lraspp mobilenet models
-# (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`).
-#
-# Let's start by looking at the output of the model. Remember that in general,
-# images must be normalized before they're passed to a semantic segmentation
-# model.
+# :func:`~torchvision.models.segmentation.fcn_resnet50`. Let's start by looking
+# at the output of the model.

from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

@@ -131,8 +121,8 @@ def show(imgs):
model = fcn_resnet50(weights=weights, progress=False)
model = model.eval()

-normalized_batch = transforms(batch)
-output = model(normalized_batch)['out']
+batch = torch.stack([transforms(d) for d in dog_list])
+output = model(batch)['out']
print(output.shape, output.min().item(), output.max().item())

#####################################
@@ -145,18 +135,13 @@ def show(imgs):
# Let's plot the masks that have been detected for the dog class and for the
# boat class:

-sem_classes = [
-    '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
-    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
-    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
-]
-sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
+sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}

normalized_masks = torch.nn.functional.softmax(output, dim=1)

dog_and_boat_masks = [
    normalized_masks[img_idx, sem_class_to_idx[cls]]
-    for img_idx in range(batch.shape[0])
+    for img_idx in range(len(dog_list))
    for cls in ('dog', 'boat')
]

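The key change in this hunk is that the class list now comes from the weight metadata instead of being hard-coded. A minimal sketch of the segmentation idiom, assuming `dog_list` from above:

import torch
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

weights = FCN_ResNet50_Weights.DEFAULT
transforms = weights.transforms()
model = fcn_resnet50(weights=weights, progress=False).eval()

# Segmentation models take a single batched tensor, so the transformed
# images are stacked (unlike detection models, which take a list).
batch = torch.stack([transforms(d) for d in dog_list])
output = model(batch)['out']

# Class names ship with the weights; no hand-maintained list is needed.
sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
normalized_masks = torch.nn.functional.softmax(output, dim=1)
dog_masks = normalized_masks[:, sem_class_to_idx['dog']]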
@@ -195,7 +180,7 @@ def show(imgs):

dogs_with_masks = [
    draw_segmentation_masks(img, masks=mask, alpha=0.7)
-    for img, mask in zip(batch_int, boolean_dog_masks)
+    for img, mask in zip(dog_list, boolean_dog_masks)
]
show(dogs_with_masks)

@@ -241,7 +226,7 @@ def show(imgs):

dogs_with_masks = [
    draw_segmentation_masks(img, masks=mask, alpha=.6)
-    for img, mask in zip(batch_int, all_classes_masks)
+    for img, mask in zip(dog_list, all_classes_masks)
]
show(dogs_with_masks)

@@ -272,12 +257,12 @@ def show(imgs):
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

-batch = transforms(batch_int)
+images = [transforms(d) for d in dog_list]

model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
model = model.eval()

-output = model(batch)
+output = model(images)
print(output)

#####################################
@@ -304,30 +289,13 @@ def show(imgs):
      f"min = {dog1_masks.min()}, max = {dog1_masks.max()}")

#####################################
-# Here the masks corresponds to probabilities indicating, for each pixel, how
+# Here the masks correspond to probabilities indicating, for each pixel, how
# likely it is to belong to the predicted label of that instance. Those
# predicted labels correspond to the 'labels' element in the same output dict.
# Let's see which labels were predicted for the instances of the first image.

-inst_classes = [
-    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
-    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
-    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
-    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
-    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
-    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
-    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
-    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
-    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
-    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
-    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
-    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
-]
-
-inst_class_to_idx = {cls: idx for (idx, cls) in enumerate(inst_classes)}
-
print("For the first dog, the following instances were detected:")
-print([inst_classes[label] for label in dog1_output['labels']])
+print([weights.meta["categories"][label] for label in dog1_output['labels']])

#####################################
# Interestingly, the model detects two persons in the image. Let's go ahead and
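The same metadata lookup replaces the long hard-coded COCO list for the instance segmentation model. A sketch under the same assumptions, with `proba_threshold` as an illustrative cutoff:

from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights

weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

images = [transforms(d) for d in dog_list]
model = maskrcnn_resnet50_fpn(weights=weights, progress=False).eval()
output = model(images)

dog1_output = output[0]
# Look up human-readable labels in the weight metadata.
print([weights.meta["categories"][label] for label in dog1_output['labels']])

# The predicted masks are per-pixel probabilities in [0, 1]; threshold
# them to get boolean masks for draw_segmentation_masks.
proba_threshold = 0.5
dog1_bool_masks = dog1_output['masks'] > proba_threshold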
@@ -369,7 +337,7 @@ def show(imgs):

dogs_with_masks = [
    draw_segmentation_masks(img, mask.squeeze(1))
-    for img, mask in zip(batch_int, boolean_masks)
+    for img, mask in zip(dog_list, boolean_masks)
]
show(dogs_with_masks)

@@ -388,8 +356,6 @@ def show(imgs):
# torchvision's KeypointRCNN loaded with :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`.
# We will first have a look at output of the model.
#
-# Note that the keypoint detection model does not need normalized images.
-#

from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
from torchvision.io import read_image
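With the bundled transforms handling all preprocessing, the removed note about normalization is obsolete. The keypoint example follows the same pattern; a sketch, assuming an asset image named `person1.jpg` like the one the gallery loads:

from pathlib import Path

from torchvision.io import read_image
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights

person_int = read_image(str(Path('assets') / 'person1.jpg'))

weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()
person_float = transforms(person_int)

model = keypointrcnn_resnet50_fpn(weights=weights, progress=False).eval()
# As with the other detection models, the input is a list of images.
outputs = model([person_float])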
