Merged
Binary file added gallery/assets/person1.jpg
Binary file removed gallery/assets/visualization_utils_thumbnail.png
111 changes: 109 additions & 2 deletions gallery/plot_visualization_utils.py
@@ -4,10 +4,10 @@
=======================

This example illustrates some of the utilities that torchvision offers for
-visualizing images, bounding boxes, and segmentation masks.
+visualizing images, bounding boxes, segmentation masks and keypoints.
"""

-# sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail.png"
+# sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail2.png"

import torch
import numpy as np
@@ -366,3 +366,110 @@ def show(imgs):
# The two 'people' masks in the first image were not selected because they have
# a lower score than the score threshold. Similarly, in the second image, the
# instance with class 15 (which corresponds to 'bench') was not selected.

#####################################
# Visualizing keypoints
# ------------------------------
# The :func:`~torchvision.utils.draw_keypoints` function can be used to
# draw keypoints on images. We will see how to use it with
# torchvision's KeypointRCNN loaded with :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`.
# We will first have a look at the output of the model.
#
# Note that the keypoint detection model does not need normalized images.
#

from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.io import read_image

person_int = read_image(str(Path("assets") / "person1.jpg"))
person_float = convert_image_dtype(person_int, dtype=torch.float)

model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False)
model = model.eval()

outputs = model([person_float])
print(outputs)

#####################################
# As we see, the output contains a list of dictionaries.
# The output list is of length batch_size.
# We currently have just a single image, so the length of the list is 1.
# Each entry in the list corresponds to an input image,
# and it is a dict with the keys `boxes`, `labels`, `scores`, `keypoints` and `keypoint_scores`.
# Each value associated with those keys has `num_instances` elements in it.
# In our case above, there are 2 instances detected in the image.

kpts = outputs[0]['keypoints']
scores = outputs[0]['scores']

print(kpts)
print(scores)
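
#####################################
# To make the structure above concrete, we can also inspect the shapes (a
# quick sanity check added for illustration; the sizes shown assume the two
# instances detected above). Each instance carries K = 17 COCO keypoints,
# each given as (x, y, visibility).

print(kpts.shape)    # torch.Size([2, 17, 3]) -- 2 instances, 17 keypoints each
print(scores.shape)  # torch.Size([2])        -- one score per instance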

#####################################
# The KeypointRCNN model detects that there are two instances in the image.
# If you plot the boxes using :func:`~draw_bounding_boxes`,
# you will recognize that they are the person and the surfboard.
# If we look at the scores, we will realize that the model is much more
# confident about the person than about the surfboard.
# We can now set a confidence threshold and plot only the instances that we
# are confident enough about.
# Let us set a threshold of 0.75 and keep only the keypoints corresponding to the person.

detect_threshold = 0.75
idx = torch.where(scores > detect_threshold)
keypoints = kpts[idx]

print(keypoints)

#####################################
# Great, now we have the keypoints corresponding to the person.
# Each keypoint is represented by its x and y coordinates and a visibility flag.
# We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints.
# Note that the utility expects uint8 images.
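
#####################################
# As a quick aside (an illustrative sketch, not something the utility
# requires), the x, y locations can be split off from the visibility flag:

coords = keypoints[..., :2]     # (num_detected, 17, 2): x, y locations
visibility = keypoints[..., 2]  # (num_detected, 17): visibility flag
print(coords.shape, visibility.shape)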

from torchvision.utils import draw_keypoints

res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
show(res)

#####################################
# As we see, the keypoints appear as colored circles over the image.
# The COCO keypoints for a person are ordered and represent the following list.

coco_keypoints = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
]

Review thread on this block:

Contributor Author: I think this will be great when people can use metadata from the new keypoint RCNN weights. Right now we have to explicitly remember this.

Contributor Author: Although it isn't used in the code, I made this as a code block so that it is copy-pastable.

Contributor: Yes, agreed. We've already added it on the new API:

    _common_meta = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}

    _COCO_PERSON_CATEGORIES = ["no person", "person"]
    _COCO_PERSON_KEYPOINT_NAMES = [
        "nose",
        "left_eye",

Contributor Author: I guess once the new prototype models and transforms are moved to stable, we would need to revisit the gallery examples.
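
#####################################
# If you ever need to look up an index by name, a small helper dict
# (illustrative only, not part of torchvision) can be built from the
# ``coco_keypoints`` list above:

kp_name_to_idx = {name: idx for idx, name in enumerate(coco_keypoints)}
print(kp_name_to_idx["left_shoulder"])  # 5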

#####################################
# What if we are interested in joining the keypoints?
# This is especially useful for applications such as pose estimation or action recognition.
# We can join the keypoints easily using the `connectivity` parameter.
# A close observation reveals that we need to join the points in the
# following order to construct a human skeleton.
#
# nose -> left_eye -> left_ear. (0, 1), (1, 3)
#
# nose -> right_eye -> right_ear. (0, 2), (2, 4)
#
# nose -> left_shoulder -> left_elbow -> left_wrist. (0, 5), (5, 7), (7, 9)
#
# nose -> right_shoulder -> right_elbow -> right_wrist. (0, 6), (6, 8), (8, 10)
#
# left_shoulder -> left_hip -> left_knee -> left_ankle. (5, 11), (11, 13), (13, 15)
#
# right_shoulder -> right_hip -> right_knee -> right_ankle. (6, 12), (12, 14), (14, 16)
#
# We will create a list containing these keypoint IDs to be connected.

connect_skeleton = [
(0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6), (5, 7), (6, 8),
(7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16)
]
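
#####################################
# Equivalently, the same edge list can be derived from the keypoint names
# using the ``kp_name_to_idx`` helper above (a sketch added for illustration;
# it avoids hard-coding the indices):

skeleton_by_name = [
    ("nose", "left_eye"), ("nose", "right_eye"),
    ("left_eye", "left_ear"), ("right_eye", "right_ear"),
    ("nose", "left_shoulder"), ("nose", "right_shoulder"),
    ("left_shoulder", "left_elbow"), ("right_shoulder", "right_elbow"),
    ("left_elbow", "left_wrist"), ("right_elbow", "right_wrist"),
    ("left_shoulder", "left_hip"), ("right_shoulder", "right_hip"),
    ("left_hip", "left_knee"), ("right_hip", "right_knee"),
    ("left_knee", "left_ankle"), ("right_knee", "right_ankle"),
]
derived = [(kp_name_to_idx[a], kp_name_to_idx[b]) for a, b in skeleton_by_name]
assert derived == connect_skeleton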

#####################################
# We pass the above list to the `connectivity` parameter to connect the keypoints.
#

res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
show(res)
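
#####################################
# If you want to keep the rendered image (a small convenience sketch, not
# part of the original example), the uint8 tensor can be converted to a PIL
# image and saved; the filename here is just an illustration.

from torchvision.transforms.functional import to_pil_image
to_pil_image(res).save("person_skeleton.png")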
18 changes: 16 additions & 2 deletions test/test_utils.py
@@ -256,7 +256,14 @@ def test_draw_keypoints_vanilla():

    img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    img_cp = img.clone()
-    result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1),))
+    result = utils.draw_keypoints(
+        img,
+        keypoints,
+        colors="red",
+        connectivity=[
+            (0, 1),
+        ],
+    )
    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
    if not os.path.exists(path):
        res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
@@ -277,7 +284,14 @@ def test_draw_keypoints_colored(colors):

    img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    img_cp = img.clone()
-    result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1),))
+    result = utils.draw_keypoints(
+        img,
+        keypoints,
+        colors=colors,
+        connectivity=[
+            (0, 1),
+        ],
+    )
    assert result.size(0) == 3
    assert_equal(keypoints, keypoints_cp)
    assert_equal(img, img_cp)
4 changes: 2 additions & 2 deletions torchvision/utils.py
@@ -304,7 +304,7 @@ def draw_segmentation_masks(
def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
-    connectivity: Optional[Tuple[Tuple[int, int]]] = None,
+    connectivity: Optional[List[Tuple[int, int]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: int = 2,
    width: int = 3,
@@ -318,7 +318,7 @@ def draw_keypoints(
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
            in the format [x, y].
-        connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where,
+        connectivity (List[Tuple[int, int]]): A list of tuples where
            each tuple contains a pair of keypoints to be connected.
        colors (str, Tuple): The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
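
For reference, a minimal call exercising the new `List[Tuple[int, int]]` connectivity type might look like this (a sketch with dummy data, not part of the PR):

    import torch
    from torchvision.utils import draw_keypoints

    # One instance with two keypoints, in (x, y) format.
    img = torch.zeros((3, 100, 100), dtype=torch.uint8)
    kpts = torch.tensor([[[10.0, 10.0], [50.0, 50.0]]])
    out = draw_keypoints(img, kpts, connectivity=[(0, 1)], colors="red")
    print(out.shape)  # torch.Size([3, 100, 100])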