diff --git a/test/test_utils.py b/test/test_utils.py index 727208ec16c..fab8c5fc082 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -176,6 +176,15 @@ def test_draw_boxes_warning(): utils.draw_bounding_boxes(img, boxes, font_size=11) +def test_draw_no_boxes(): + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + boxes = torch.full((0, 4), 0, dtype=torch.float) + with pytest.warns(UserWarning, match=re.escape("boxes doesn't contain any box. No box was drawn")): + res = utils.draw_bounding_boxes(img, boxes) + # Check that the function didnt change the image + assert res.eq(img).all() + + @pytest.mark.parametrize( "colors", [ @@ -266,6 +275,15 @@ def test_draw_segmentation_masks_errors(): utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) +def test_draw_no_segmention_mask(): + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + masks = torch.full((0, 100, 100), 0, dtype=torch.bool) + with pytest.warns(UserWarning, match=re.escape("masks doesn't contain any mask. No mask was drawn")): + res = utils.draw_segmentation_masks(img, masks) + # Check that the function didnt change the image + assert res.eq(img).all() + + def test_draw_keypoints_vanilla(): # Keypoints is declared on top as global variable keypoints_cp = keypoints.clone() diff --git a/torchvision/utils.py b/torchvision/utils.py index afbc1332105..80627ca5aa2 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -211,6 +211,10 @@ def draw_bounding_boxes( num_boxes = boxes.shape[0] + if num_boxes == 0: + warnings.warn("boxes doesn't contain any box. No box was drawn") + return image + if labels is None: labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] elif len(labels) != num_boxes: @@ -311,6 +315,10 @@ def draw_segmentation_masks( if colors is not None and num_masks > len(colors): raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") + if num_masks == 0: + warnings.warn("masks doesn't contain any mask. No mask was drawn") + return image + if colors is None: colors = _generate_color_palette(num_masks)