Description
On uint8 tensors, Resize() currently converts the input image to and from torch.float to pass it to interpolate(), because interpolate() didn't support native uint8 inputs in the past. This is suboptimal.
@vfdev-5 and I have recently implemented native uint8 support for interpolate(mode="bilinear") in pytorch/pytorch#90771 and pytorch/pytorch#96848.
We should integrate this native uint8 support into torchvision's Resize(). Benchmarks below show that such integration could make Resize() at least 3X faster, which saves about 1 ms per image and cuts total pipeline time by about 30% for a typical classification pipeline (including auto-augment, which is the next bottleneck). This would make the Tensor / DataPoint backend significantly faster than PIL.
Some challenges to address before integration:
- Improvements for native uint8 are mostly for AVX2 archs. Compared to the current Resize() implementation (float), is the perf still OK on archs that don't support AVX2? First, we need to identify whether those non-AVX2 targets are critical or not.
- BC: Although more strictly correct, the native uint8 path may have off-by-one differences with the current float path. Mitigation: only integrate native uint8 into V2 Resize(), where BC commitments are looser.
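To make both challenges concrete, here is a minimal sketch that reports the CPU capability PyTorch detected and measures how far the native uint8 output is from the float round trip. It assumes a PyTorch build that includes the two PRs above; the image size, the seed, and the rounding and clamping used for the float path are illustrative choices, not necessarily what Resize() does internally. `torch.backends.cpu.get_cpu_capability()` may be missing on older builds, hence the guarded import.

```python
import torch
import torch.nn.functional as F

# Report the vectorization level PyTorch detected (e.g. "AVX2", "AVX512", "DEFAULT").
# Assumption: this helper only exists in recent PyTorch builds.
try:
    from torch.backends.cpu import get_cpu_capability
    capability = get_cpu_capability()
except ImportError:
    capability = "unknown"
print("CPU capability:", capability)

torch.manual_seed(0)
# Hypothetical input: one random 3x256x256 uint8 image with a batch dimension.
img = torch.randint(0, 256, (1, 3, 256, 256), dtype=torch.uint8)
kwargs = dict(size=[224, 224], mode="bilinear", antialias=True, align_corners=False)

# Float round trip: cast to float32, interpolate, round, cast back to uint8.
via_float = F.interpolate(img.float(), **kwargs).round().clamp(0, 255).to(torch.uint8)

# Native path: pass the uint8 tensor to interpolate() directly.
native = F.interpolate(img, **kwargs)

# Compare in a signed dtype so the subtraction cannot wrap around.
diff = (via_float.to(torch.int16) - native.to(torch.int16)).abs()
print("max abs difference:", int(diff.max()))
print("fraction of differing pixels:", float((diff > 0).float().mean()))
```

Running this on both AVX2 and non-AVX2 machines would show whether the differences stay small on each target.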
Benchmarks were made with @pmeier's pmeier/detection-reference-benchmark@0ae9027 and the following patch:
+class ResizeUint8(torch.nn.Module):
+ def __init__(self, force_channels_last):
+ super().__init__()
+ self.force_channels_last = force_channels_last
+
+ def forward(self, img):
+ img = img.unsqueeze(0)
+ if self.force_channels_last:
+ img = img.contiguous(memory_format=torch.channels_last)
+ return torch.nn.functional.interpolate(img, size=[223, 223], mode="bilinear", antialias=True, align_corners=None).squeeze(0)
+
def classification_complex_pipeline_builder(*, input_type, api_version):
if input_type == "Datapoint" and api_version == "v1":
return None
@@ -94,9 +106,15 @@ def classification_complex_pipeline_builder(*, input_type, api_version):
if api_version == "v1":
transforms = transforms_v1
RandomResizedCropWithoutResize = RandomResizedCropWithoutResizeV1
+ resize = transforms.Resize(223, antialias=True)
elif api_version == "v2":
transforms = transforms_v2
RandomResizedCropWithoutResize = RandomResizedCropWithoutResizeV2
+ if input_type in ("Datapoint", "Tensor"):
+ # resize = ResizeUint8(force_channels_last=False)
+ resize = transforms.Resize(223, antialias=True)
+ else:
+ resize = transforms.Resize(223, antialias=True)
else:
raise RuntimeError(f"Got {api_version=}")
@@ -106,11 +124,14 @@ def classification_complex_pipeline_builder(*, input_type, api_version):
pipeline.append(transforms.PILToTensor())
elif input_type == "Datapoint":
pipeline.append(transforms.ToImageTensor())
+
+
pipeline.extend(
[
RandomResizedCropWithoutResize(224),
- transforms.Resize(224, antialias=True),
+ # transforms.Resize(223, antialias=True),
+ resize,
transforms.RandomHorizontalFlip(p=0.5),
transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
]
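For a quicker check outside the full benchmark harness, the Resize step alone can be timed with `torch.utils.benchmark`. This is only a rough sketch: the input size, the single-thread setting, and the helper names `float_path` and `native_path` are assumptions made for illustration, and the numbers it prints will not match the pipeline results reported here.

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

# Hypothetical input: one 3x400x500 uint8 image with a batch dimension.
img = torch.randint(0, 256, (1, 3, 400, 500), dtype=torch.uint8)
kwargs = dict(size=[224, 224], mode="bilinear", antialias=True, align_corners=False)


def float_path(x):
    # Convert to float, interpolate, then round and cast back to uint8.
    out = F.interpolate(x.float(), **kwargs)
    return out.round().clamp(0, 255).to(torch.uint8)


def native_path(x):
    # Let interpolate() work on the uint8 tensor directly.
    return F.interpolate(x, **kwargs)


medians_us = {}
for label, fn in [("float round trip", float_path), ("native uint8", native_path)]:
    timer = benchmark.Timer(
        stmt="fn(img)",
        globals={"fn": fn, "img": img},
        num_threads=1,
        label=label,
    )
    measurement = timer.blocked_autorange(min_run_time=0.2)
    medians_us[label] = measurement.median * 1e6  # seconds to microseconds
    print(f"{label}: {medians_us[label]:.0f} µs")
```

Adding `img = img.contiguous(memory_format=torch.channels_last)` before the loop gives the equivalent of `force_channels_last=True` in the patch.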
Without uint8 native support:
############################################################
classification-complex
############################################################
input_type='Tensor', api_version='v1'
Results computed for 1_000 samples
median std
PILToTensor 258 µs +- 24 µs
RandomResizedCropWithoutResizeV1 111 µs +- 22 µs
Resize 1238 µs +- 311 µs
RandomHorizontalFlip 53 µs +- 21 µs
AutoAugment 1281 µs +- 840 µs
RandomErasing 31 µs +- 66 µs
ConvertImageDtype 120 µs +- 13 µs
Normalize 186 µs +- 23 µs
total 3278 µs
------------------------------------------------------------
input_type='Tensor', api_version='v2'
Results computed for 1_000 samples
median std
PILToTensor 271 µs +- 21 µs
RandomResizedCropWithoutResizeV2 113 µs +- 17 µs
Resize 1226 µs +- 304 µs
RandomHorizontalFlip 64 µs +- 24 µs
AutoAugment 1099 µs +- 738 µs
RandomErasing 39 µs +- 68 µs
ConvertDtype 96 µs +- 12 µs
Normalize 150 µs +- 17 µs
total 3057 µs
------------------------------------------------------------
input_type='PIL', api_version='v1'
Results computed for 1_000 samples
median std
RandomResizedCropWithoutResizeV1 162 µs +- 27 µs
Resize 787 µs +- 186 µs
RandomHorizontalFlip 53 µs +- 29 µs
AutoAugment 585 µs +- 342 µs
PILToTensor 96 µs +- 9 µs
RandomErasing 32 µs +- 65 µs
ConvertImageDtype 125 µs +- 14 µs
Normalize 850 µs +- 83 µs
total 2688 µs
------------------------------------------------------------
input_type='PIL', api_version='v2'
Results computed for 1_000 samples
median std
RandomResizedCropWithoutResizeV2 166 µs +- 26 µs
Resize 783 µs +- 185 µs
RandomHorizontalFlip 61 µs +- 33 µs
AutoAugment 489 µs +- 355 µs
PILToTensor 115 µs +- 9 µs
RandomErasing 37 µs +- 65 µs
ConvertDtype 101 µs +- 11 µs
Normalize 825 µs +- 84 µs
total 2577 µs
------------------------------------------------------------
input_type='Datapoint', api_version='v2'
Results computed for 1_000 samples
median std
ToImageTensor 284 µs +- 22 µs
RandomResizedCropWithoutResizeV2 119 µs +- 17 µs
Resize 1223 µs +- 302 µs
RandomHorizontalFlip 62 µs +- 29 µs
AutoAugment 1100 µs +- 625 µs
RandomErasing 39 µs +- 72 µs
ConvertDtype 106 µs +- 13 µs
Normalize 155 µs +- 16 µs
total 3089 µs
------------------------------------------------------------
Summaries
v2 / v1
Tensor 0.93
PIL 0.96
[a] [b] [c] [d] [e]
Tensor, v1, [a] 1.00 1.07 1.22 1.27 1.06
Tensor, v2, [b] 0.93 1.00 1.14 1.19 0.99
PIL, v1, [c] 0.82 0.88 1.00 1.04 0.87
PIL, v2, [d] 0.79 0.84 0.96 1.00 0.83
Datapoint, v2, [e] 0.94 1.01 1.15 1.20 1.00
Slowdown as row / col
With uint8 native support for TensorV2 and DatapointV2:
############################################################
classification-complex
############################################################
input_type='Tensor', api_version='v1'
Results computed for 1_000 samples
median std
PILToTensor 255 µs +- 21 µs
RandomResizedCropWithoutResizeV1 110 µs +- 22 µs
Resize 1230 µs +- 315 µs
RandomHorizontalFlip 47 µs +- 24 µs
AutoAugment 1269 µs +- 870 µs
RandomErasing 31 µs +- 66 µs
ConvertImageDtype 121 µs +- 13 µs
Normalize 186 µs +- 23 µs
total 3249 µs
------------------------------------------------------------
input_type='Tensor', api_version='v2'
Results computed for 1_000 samples
median std
PILToTensor 270 µs +- 20 µs
RandomResizedCropWithoutResizeV2 110 µs +- 17 µs
ResizeUint8 402 µs +- 109 µs
RandomHorizontalFlip 66 µs +- 24 µs
AutoAugment 996 µs +- 539 µs
RandomErasing 39 µs +- 64 µs
ConvertDtype 81 µs +- 10 µs
Normalize 134 µs +- 14 µs
total 2099 µs
------------------------------------------------------------
input_type='PIL', api_version='v1'
Results computed for 1_000 samples
median std
RandomResizedCropWithoutResizeV1 161 µs +- 28 µs
Resize 779 µs +- 186 µs
RandomHorizontalFlip 53 µs +- 29 µs
AutoAugment 576 µs +- 339 µs
PILToTensor 93 µs +- 8 µs
RandomErasing 31 µs +- 64 µs
ConvertImageDtype 123 µs +- 13 µs
Normalize 843 µs +- 82 µs
total 2661 µs
------------------------------------------------------------
input_type='PIL', api_version='v2'
Results computed for 1_000 samples
median std
RandomResizedCropWithoutResizeV2 163 µs +- 26 µs
Resize 788 µs +- 180 µs
RandomHorizontalFlip 62 µs +- 33 µs
AutoAugment 492 µs +- 355 µs
PILToTensor 112 µs +- 9 µs
RandomErasing 37 µs +- 64 µs
ConvertDtype 100 µs +- 11 µs
Normalize 826 µs +- 86 µs
total 2580 µs
------------------------------------------------------------
input_type='Datapoint', api_version='v2'
Results computed for 1_000 samples
median std
ToImageTensor 284 µs +- 22 µs
RandomResizedCropWithoutResizeV2 118 µs +- 17 µs
ResizeUint8 410 µs +- 109 µs
RandomHorizontalFlip 68 µs +- 23 µs
AutoAugment 994 µs +- 542 µs
RandomErasing 38 µs +- 63 µs
ConvertDtype 81 µs +- 10 µs
Normalize 133 µs +- 14 µs
total 2127 µs
------------------------------------------------------------
Summaries
v2 / v1
Tensor 0.65
PIL 0.97
[a] [b] [c] [d] [e]
Tensor, v1, [a] 1.00 1.55 1.22 1.26 1.53
Tensor, v2, [b] 0.65 1.00 0.79 0.81 0.99
PIL, v1, [c] 0.82 1.27 1.00 1.03 1.25
PIL, v2, [d] 0.79 1.23 0.97 1.00 1.21
Datapoint, v2, [e] 0.65 1.01 0.80 0.82 1.00
Slowdown as row / col
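The summary table and the headline numbers can be re-derived from the per-pipeline totals. The sketch below recomputes the "row / col" ratios from the totals of the run with native uint8 support, plus the Resize speedup and the total-time reduction for the Tensor v2 pipeline; the variable and function names are made up for illustration.

```python
# Pipeline totals in µs, copied from the run with native uint8 support.
totals_us = {
    "Tensor, v1": 3249,
    "Tensor, v2": 2099,
    "PIL, v1": 2661,
    "PIL, v2": 2580,
    "Datapoint, v2": 2127,
}


def slowdown_matrix(totals):
    # Entry [row][col] is total(row) / total(col), i.e. "slowdown as row / col".
    return {row: {col: totals[row] / totals[col] for col in totals} for row in totals}


matrix = slowdown_matrix(totals_us)
for row, cols in matrix.items():
    print(f"{row:<14}", "  ".join(f"{value:.2f}" for value in cols.values()))

# Resize median for Tensor v2: 1226 µs (float path) vs 402 µs (ResizeUint8).
resize_speedup = 1226 / 402
# Tensor v2 pipeline total: 3057 µs without vs 2099 µs with native uint8.
pipeline_saving = 1 - 2099 / 3057
print(f"Resize speedup (Tensor, v2): {resize_speedup:.2f}x")
print(f"Total pipeline time reduction (Tensor, v2): {pipeline_saving:.0%}")
```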