From b00b5225e29faea4f236cc51091bc5b6cc4cbb07 Mon Sep 17 00:00:00 2001 From: oxabz Date: Thu, 21 Apr 2022 16:29:39 +0200 Subject: [PATCH 01/10] Fixing the IndexError in draw_segmentation_masks --- torchvision/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchvision/utils.py b/torchvision/utils.py index e82752ab28b..ce5a00f3765 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -206,6 +206,10 @@ def draw_bounding_boxes( num_boxes = boxes.shape[0] + if num_boxes == 0: + warnings.warn(f"boxes doesn't contain any box. No box was drawn") + return image.to(out_dtype) + if labels is None: labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] elif len(labels) != num_boxes: @@ -306,6 +310,10 @@ def draw_segmentation_masks( if colors is not None and num_masks > len(colors): raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") + if num_masks == 0: + warnings.warn(f"masks doesn't contain any mask. No mask was drawn") + return image.to(out_dtype) + if colors is None: colors = _generate_color_palette(num_masks) From 6e695ad9c0a87e30050a86900a5807ecd3dcc92b Mon Sep 17 00:00:00 2001 From: LEGRAND Matthieu Date: Thu, 21 Apr 2022 19:04:50 +0200 Subject: [PATCH 02/10] fixing the bug on draw_bounding_boxes --- torchvision/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index ce5a00f3765..73552764e67 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -208,7 +208,7 @@ def draw_bounding_boxes( if num_boxes == 0: warnings.warn(f"boxes doesn't contain any box. No box was drawn") - return image.to(out_dtype) + return image.to(torch.uint8) if labels is None: labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] @@ -310,9 +310,11 @@ def draw_segmentation_masks( if colors is not None and num_masks > len(colors): raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") + out_dtype = torch.uint8 + if num_masks == 0: warnings.warn(f"masks doesn't contain any mask. No mask was drawn") - return image.to(out_dtype) + return image.to(out_dtype) if colors is None: colors = _generate_color_palette(num_masks) @@ -324,8 +326,6 @@ def draw_segmentation_masks( if isinstance(colors[0], tuple) and len(colors[0]) != 3: raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") - out_dtype = torch.uint8 - colors_ = [] for color in colors: if isinstance(color, str): From b51cbb22b4a97002423a3dd8f1284f7d724e4e92 Mon Sep 17 00:00:00 2001 From: LEGRAND Matthieu Date: Thu, 21 Apr 2022 19:14:31 +0200 Subject: [PATCH 03/10] Changing fstring to normal string --- torchvision/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 73552764e67..ef1b32e599c 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -207,7 +207,7 @@ def draw_bounding_boxes( num_boxes = boxes.shape[0] if num_boxes == 0: - warnings.warn(f"boxes doesn't contain any box. No box was drawn") + warnings.warn("boxes doesn't contain any box. No box was drawn") return image.to(torch.uint8) if labels is None: @@ -313,7 +313,7 @@ def draw_segmentation_masks( out_dtype = torch.uint8 if num_masks == 0: - warnings.warn(f"masks doesn't contain any mask. No mask was drawn") + warnings.warn("masks doesn't contain any mask. No mask was drawn") return image.to(out_dtype) if colors is None: From a1b91cbfa5e1e047dd11a2193ac5ebe5852bd432 Mon Sep 17 00:00:00 2001 From: LEGRAND Matthieu Date: Fri, 22 Apr 2022 13:51:54 +0200 Subject: [PATCH 04/10] Removing unecessary conversion --- torchvision/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index ef1b32e599c..512f0de4b48 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -208,7 +208,7 @@ def draw_bounding_boxes( if num_boxes == 0: warnings.warn("boxes doesn't contain any box. No box was drawn") - return image.to(torch.uint8) + return image if labels is None: labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] @@ -310,11 +310,9 @@ def draw_segmentation_masks( if colors is not None and num_masks > len(colors): raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") - out_dtype = torch.uint8 - if num_masks == 0: warnings.warn("masks doesn't contain any mask. No mask was drawn") - return image.to(out_dtype) + return image if colors is None: colors = _generate_color_palette(num_masks) @@ -337,6 +335,8 @@ def draw_segmentation_masks( for mask, color in zip(masks, colors_): img_to_draw[:, mask] = color[:, None] + out_dtype = torch.uint8 + out = image * (1 - alpha) + img_to_draw * alpha return out.to(out_dtype) From d9e4ee1632967724b2224bf9cb85e34426b509a8 Mon Sep 17 00:00:00 2001 From: LEGRAND Matthieu Date: Fri, 22 Apr 2022 14:06:10 +0200 Subject: [PATCH 05/10] Adding test for the change --- test/test_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index 727208ec16c..7af16d6be2b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,6 +4,8 @@ import tempfile from io import BytesIO +import warnings + import numpy as np import pytest import torch @@ -120,7 +122,6 @@ def test_draw_boxes_colors(colors): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors=colors) - def test_draw_boxes_vanilla(): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img_cp = img.clone() @@ -176,6 +177,14 @@ def test_draw_boxes_warning(): utils.draw_bounding_boxes(img, boxes, font_size=11) +def test_draw_no_boxes(): + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + boxes = torch.full((0,4), 0, dtype=torch.float) + with pytest.warns(UserWarning, match=re.escape("boxes doesn't contain any box. No box was drawn")): + res = utils.draw_bounding_boxes(img, boxes) + assert res.eq(img) + + @pytest.mark.parametrize( "colors", [ From 7adff47e79493aca769b9ee46cdc5228826f40a5 Mon Sep 17 00:00:00 2001 From: LEGRAND Matthieu Date: Fri, 22 Apr 2022 14:08:17 +0200 Subject: [PATCH 06/10] Adding a test for draw seqmentation mask --- test/test_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index 7af16d6be2b..59a5da4274d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -274,6 +274,12 @@ def test_draw_segmentation_masks_errors(): bad_colors = ("red", "blue") # should be a list utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) +def test_draw_no_segmention_mask(): + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + boxes = torch.full((0,4), 0, dtype=torch.float) + with pytest.warns(UserWarning, match=re.escape("masks doesn't contain any mask. No mask was drawn")): + res = utils.draw_segmentation_masks(img, boxes) + assert res.eq(img) def test_draw_keypoints_vanilla(): # Keypoints is declared on top as global variable From 462ba9316137d225e4c6d893dfd975188cd48457 Mon Sep 17 00:00:00 2001 From: LEGRAND Matthieu Date: Fri, 22 Apr 2022 14:11:02 +0200 Subject: [PATCH 07/10] Fixing small mistake --- torchvision/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 512f0de4b48..9c7a44b3959 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -324,6 +324,8 @@ def draw_segmentation_masks( if isinstance(colors[0], tuple) and len(colors[0]) != 3: raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") + out_dtype = torch.uint8 + colors_ = [] for color in colors: if isinstance(color, str): @@ -335,8 +337,6 @@ def draw_segmentation_masks( for mask, color in zip(masks, colors_): img_to_draw[:, mask] = color[:, None] - out_dtype = torch.uint8 - out = image * (1 - alpha) + img_to_draw * alpha return out.to(out_dtype) From 32a12b5e755ab1643c9386e5527dada13532bdd3 Mon Sep 17 00:00:00 2001 From: LEGRAND Matthieu Date: Fri, 22 Apr 2022 14:17:24 +0200 Subject: [PATCH 08/10] Fixing an error in the tests --- test/test_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 59a5da4274d..4e8cc1374df 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -182,8 +182,8 @@ def test_draw_no_boxes(): boxes = torch.full((0,4), 0, dtype=torch.float) with pytest.warns(UserWarning, match=re.escape("boxes doesn't contain any box. No box was drawn")): res = utils.draw_bounding_boxes(img, boxes) - assert res.eq(img) - + # Check that the function didnt change the image + assert res.eq(img).all() @pytest.mark.parametrize( "colors", @@ -276,10 +276,11 @@ def test_draw_segmentation_masks_errors(): def test_draw_no_segmention_mask(): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) - boxes = torch.full((0,4), 0, dtype=torch.float) + masks = torch.full((0, 100, 100), 0, dtype=torch.bool) with pytest.warns(UserWarning, match=re.escape("masks doesn't contain any mask. No mask was drawn")): - res = utils.draw_segmentation_masks(img, boxes) - assert res.eq(img) + res = utils.draw_segmentation_masks(img, masks) + # Check that the function didnt change the image + assert res.eq(img).all() def test_draw_keypoints_vanilla(): # Keypoints is declared on top as global variable From f527d4e33d6883d86c7691815e1b9de93557417e Mon Sep 17 00:00:00 2001 From: LEGRAND Matthieu Date: Thu, 28 Apr 2022 14:59:28 +0200 Subject: [PATCH 09/10] removing useless imports --- test/test_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 4e8cc1374df..b643504d2fb 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,8 +4,6 @@ import tempfile from io import BytesIO -import warnings - import numpy as np import pytest import torch From 0467907d21cf229d3bfa8b548980e29383ea972f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 23 May 2022 10:41:09 +0100 Subject: [PATCH 10/10] ufmt --- test/test_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index b643504d2fb..fab8c5fc082 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -120,6 +120,7 @@ def test_draw_boxes_colors(colors): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors=colors) + def test_draw_boxes_vanilla(): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img_cp = img.clone() @@ -177,12 +178,13 @@ def test_draw_boxes_warning(): def test_draw_no_boxes(): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) - boxes = torch.full((0,4), 0, dtype=torch.float) + boxes = torch.full((0, 4), 0, dtype=torch.float) with pytest.warns(UserWarning, match=re.escape("boxes doesn't contain any box. No box was drawn")): res = utils.draw_bounding_boxes(img, boxes) # Check that the function didnt change the image assert res.eq(img).all() + @pytest.mark.parametrize( "colors", [ @@ -272,6 +274,7 @@ def test_draw_segmentation_masks_errors(): bad_colors = ("red", "blue") # should be a list utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) + def test_draw_no_segmention_mask(): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) masks = torch.full((0, 100, 100), 0, dtype=torch.bool) @@ -280,6 +283,7 @@ def test_draw_no_segmention_mask(): # Check that the function didnt change the image assert res.eq(img).all() + def test_draw_keypoints_vanilla(): # Keypoints is declared on top as global variable keypoints_cp = keypoints.clone()