|
13 | 13 |
|
14 | 14 | from common_utils import TransformsTester |
15 | 15 |
|
16 | | -from typing import Dict, List, Tuple |
| 16 | +from typing import Dict, List, Sequence, Tuple |
17 | 17 |
|
18 | 18 |
|
19 | 19 | NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC |
@@ -409,46 +409,58 @@ def test_resize(self): |
409 | 409 | batch_tensors = batch_tensors.to(dt) |
410 | 410 |
|
411 | 411 | for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]: |
412 | | - for interpolation in [BILINEAR, BICUBIC, NEAREST]: |
413 | | - resized_tensor = F.resize(tensor, size=size, interpolation=interpolation) |
414 | | - resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation) |
415 | | - |
416 | | - self.assertEqual( |
417 | | - resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation) |
418 | | - ) |
419 | | - |
420 | | - if interpolation not in [NEAREST, ]: |
421 | | - # We can not check values if mode = NEAREST, as results are different |
422 | | - # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]] |
423 | | - # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]] |
424 | | - resized_tensor_f = resized_tensor |
425 | | - # we need to cast to uint8 to compare with PIL image |
426 | | - if resized_tensor_f.dtype == torch.uint8: |
427 | | - resized_tensor_f = resized_tensor_f.to(torch.float) |
428 | | - |
429 | | - # Pay attention to high tolerance for MAE |
430 | | - self.approxEqualTensorToPIL( |
431 | | - resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation) |
| 412 | + for max_size in (None, 33, 40, 1000): |
| 413 | + if max_size is not None and isinstance(size, Sequence) and len(size) != 1: |
| 414 | + continue # unsupported, see assertRaises below |
| 415 | + for interpolation in [BILINEAR, BICUBIC, NEAREST]: |
| 416 | + resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size) |
| 417 | + resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size) |
| 418 | + |
| 419 | + self.assertEqual( |
| 420 | + resized_tensor.size()[1:], resized_pil_img.size[::-1], |
| 421 | + msg="{}, {}".format(size, interpolation) |
432 | 422 | ) |
433 | 423 |
|
434 | | - if isinstance(size, int): |
435 | | - script_size = [size, ] |
436 | | - else: |
437 | | - script_size = size |
| 424 | + if interpolation not in [NEAREST, ]: |
| 425 | + # We can not check values if mode = NEAREST, as results are different |
| 426 | + # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]] |
| 427 | + # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]] |
| 428 | + resized_tensor_f = resized_tensor |
| 429 | + # we need to cast to uint8 to compare with PIL image |
| 430 | + if resized_tensor_f.dtype == torch.uint8: |
| 431 | + resized_tensor_f = resized_tensor_f.to(torch.float) |
| 432 | + |
| 433 | + # Pay attention to high tolerance for MAE |
| 434 | + self.approxEqualTensorToPIL( |
| 435 | + resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation) |
| 436 | + ) |
438 | 437 |
|
439 | | - resize_result = script_fn(tensor, size=script_size, interpolation=interpolation) |
440 | | - self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation)) |
| 438 | + if isinstance(size, int): |
| 439 | + script_size = [size, ] |
| 440 | + else: |
| 441 | + script_size = size |
441 | 442 |
|
442 | | - self._test_fn_on_batch( |
443 | | - batch_tensors, F.resize, size=script_size, interpolation=interpolation |
444 | | - ) |
| 443 | + resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, |
| 444 | + max_size=max_size) |
| 445 | + self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation)) |
| 446 | + |
| 447 | + self._test_fn_on_batch( |
| 448 | + batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size |
| 449 | + ) |
445 | 450 |
|
446 | 451 | # assert changed type warning |
447 | 452 | with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"): |
448 | 453 | res1 = F.resize(tensor, size=32, interpolation=2) |
449 | 454 | res2 = F.resize(tensor, size=32, interpolation=BILINEAR) |
450 | 455 | self.assertTrue(res1.equal(res2)) |
451 | 456 |
|
| 457 | + for img in (tensor, pil_img): |
| 458 | + exp_msg = "max_size should only be passed if size specifies the length of the smaller edge" |
| 459 | + with self.assertRaisesRegex(ValueError, exp_msg): |
| 460 | + F.resize(img, size=(32, 34), max_size=35) |
| 461 | + with self.assertRaisesRegex(ValueError, "max_size = 32 must be strictly greater"): |
| 462 | + F.resize(img, size=32, max_size=32) |
| 463 | + |
452 | 464 | def test_resized_crop(self): |
453 | 465 | # test values of F.resized_crop in several cases: |
454 | 466 | # 1) resize to the same size, crop to the same size => should be identity |
|
0 commit comments