From d728903c995dd5ba73ba100a1736f14e18f7dd7b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Mar 2021 17:43:41 +0000 Subject: [PATCH] use parameterized on test_resize --- .../unittest/linux/scripts/environment.yml | 1 + .circleci/unittest/linux/scripts/install.sh | 2 +- .../unittest/windows/scripts/environment.yml | 1 + .circleci/unittest/windows/scripts/install.sh | 2 +- test/test_functional_tensor.py | 99 ++++++++++--------- 5 files changed, 59 insertions(+), 46 deletions(-) diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.circleci/unittest/linux/scripts/environment.yml index dcad1abfa31..2dbf56e6975 100644 --- a/.circleci/unittest/linux/scripts/environment.yml +++ b/.circleci/unittest/linux/scripts/environment.yml @@ -15,4 +15,5 @@ dependencies: - future - pillow>=4.1.1 - scipy + - parameterized - av diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh index 1a3e5c6f4d2..b68865d1708 100755 --- a/.circleci/unittest/linux/scripts/install.sh +++ b/.circleci/unittest/linux/scripts/install.sh @@ -24,7 +24,7 @@ else fi printf "Installing PyTorch with %s\n" "${cudatoolkit}" -conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" +conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" parameterized printf "* Installing torchvision\n" python setup.py develop diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml index b4f32cb3cad..5fb44711f1e 100644 --- a/.circleci/unittest/windows/scripts/environment.yml +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -17,3 +17,4 @@ dependencies: - scipy - av - dataclasses + - parameterized diff --git a/.circleci/unittest/windows/scripts/install.sh b/.circleci/unittest/windows/scripts/install.sh index 9304b4b9b65..8b717503b12 100644 --- a/.circleci/unittest/windows/scripts/install.sh +++ b/.circleci/unittest/windows/scripts/install.sh @@ -26,7 +26,7 @@ else fi printf "Installing PyTorch with %s\n" "${cudatoolkit}" -conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" +conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" parameterized printf "* Installing torchvision\n" "$this_dir/vc_env_helper.bat" python setup.py develop diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index f1219ff7ce9..93361881546 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -1,3 +1,4 @@ +import itertools import os import unittest import colorsys @@ -14,11 +15,19 @@ from common_utils import TransformsTester from typing import Dict, List, Sequence, Tuple +from parameterized import parameterized NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC +def name_func(func, _, params): + # Gives the parametrized test a decent name so they can be selected with pytest -k etc. 
+    # This should be put in some common utils file
+    return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
+
+
+
 class Tester(TransformsTester):

     def setUp(self):
@@ -392,62 +401,64 @@ def test_adjust_gamma(self):
             [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]
         )

-    def test_resize(self):
+    @parameterized.expand(list(itertools.product(
+        [None, torch.float32, torch.float64, torch.float16],
+        [32, 26, [32, ], [32, 32], (32, 32), [26, 35]],
+        [None, 33, 40, 1000],
+        [BILINEAR, BICUBIC, NEAREST]
+    )), name_func=name_func)
+    def test_resize(self, dt, size, max_size, interpolation):
         script_fn = torch.jit.script(F.resize)
         tensor, pil_img = self._create_data(26, 36, device=self.device)
         batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

-        for dt in [None, torch.float32, torch.float64, torch.float16]:
+        if dt == torch.float16 and torch.device(self.device).type == "cpu":
+            self.skipTest("skip float16 on CPU")
+        if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
+            self.skipTest("size must be an int with max_size")

-            if dt == torch.float16 and torch.device(self.device).type == "cpu":
-                # skip float16 on CPU case
-                continue
+        if dt is not None:
+            # This is a trivial cast to float of uint8 data to test all cases
+            tensor = tensor.to(dt)
+            batch_tensors = batch_tensors.to(dt)

-            if dt is not None:
-                # This is a trivial cast to float of uint8 data to test all cases
-                tensor = tensor.to(dt)
-                batch_tensors = batch_tensors.to(dt)
+        resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
+        resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)

-            for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
-                for max_size in (None, 33, 40, 1000):
-                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
-                        continue  # unsupported, see assertRaises below
-                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
-                        resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
-                        resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
-
-                        self.assertEqual(
-                            resized_tensor.size()[1:], resized_pil_img.size[::-1],
-                            msg="{}, {}".format(size, interpolation)
-                        )
+        self.assertEqual(
+            resized_tensor.size()[1:], resized_pil_img.size[::-1],
+            msg="{}, {}".format(size, interpolation)
+        )

-                        if interpolation not in [NEAREST, ]:
-                            # We can not check values if mode = NEAREST, as results are different
-                            # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
-                            # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
-                            resized_tensor_f = resized_tensor
-                            # we need to cast to uint8 to compare with PIL image
-                            if resized_tensor_f.dtype == torch.uint8:
-                                resized_tensor_f = resized_tensor_f.to(torch.float)
-
-                            # Pay attention to high tolerance for MAE
-                            self.approxEqualTensorToPIL(
-                                resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
-                            )
+        if interpolation not in [NEAREST, ]:
+            # We cannot check values if mode = NEAREST, as results are different
+            # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
+            # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
+            resized_tensor_f = resized_tensor
+            # cast uint8 tensors to float so they can be compared with the PIL image
+            if resized_tensor_f.dtype == torch.uint8:
+                resized_tensor_f = resized_tensor_f.to(torch.float)
+
+            # Pay attention to high tolerance for MAE
+            self.approxEqualTensorToPIL(
+                resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
+            )

-                        if isinstance(size, int):
-                            script_size = [size, ]
-                        else:
-                            script_size = size
+        if isinstance(size, int):
+            script_size = [size, ]
+        else:
+            script_size = size

-                        resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
-                                                  max_size=max_size)
-                        self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
+        resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
+                                  max_size=max_size)
+        self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))

-                        self._test_fn_on_batch(
-                            batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
-                        )
+        self._test_fn_on_batch(
+            batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
+        )

+    def test_resize_errors(self):
+        tensor, pil_img = self._create_data(26, 36, device=self.device)
         # assert changed type warning
         with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
             res1 = F.resize(tensor, size=32, interpolation=2)
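For reference, a minimal standalone sketch of the naming scheme used in this patch. The ExampleResizeTest class and its parameter values below are hypothetical and only for illustration; the parameterized.expand and name_func usage mirrors what the patch adds. The custom name_func embeds the argument values in each generated test's name, so individual cases can be selected with pytest -k instead of relying on index-based names.

    import unittest

    from parameterized import parameterized


    def name_func(func, _, params):
        # Same scheme as the patch: original test name plus the stringified parameters.
        return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'


    class ExampleResizeTest(unittest.TestCase):  # hypothetical class, for illustration only

        @parameterized.expand([(32, "bilinear"), (26, "nearest")], name_func=name_func)
        def test_resize(self, size, interpolation):
            # One test method is generated per parameter tuple, named e.g.
            # test_resize_32_bilinear, so `pytest -k test_resize_32_bilinear` selects it.
            self.assertGreater(size, 0)


    if __name__ == "__main__":
        unittest.main()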