Description
The standard rule for dtype support for images and videos is:
- All floating point and integer tensors are supported.
- Floating point tensors are valid in the range `[0.0, 1.0]` and integer tensors in `[0, torch.iinfo(dtype).max]`. (This is currently under review, since there were a few cases where this was not true or simply not handled. See "Don't hardcode 255 unless uint8 is enforced" #6825.)
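As a minimal sketch of this rule, the upper bound of the valid range can be derived from the dtype alone. The helpers `max_value` and `is_in_valid_range` are hypothetical names for illustration, not torchvision API:

```python
import torch


def max_value(dtype: torch.dtype) -> float:
    """Upper bound of the valid value range for an image/video dtype (hypothetical helper)."""
    if dtype.is_floating_point:
        # Floating point images are expected in [0.0, 1.0].
        return 1.0
    # Integer images are expected in [0, torch.iinfo(dtype).max].
    return torch.iinfo(dtype).max


def is_in_valid_range(image: torch.Tensor) -> bool:
    """Check that every value lies in [0, max_value(image.dtype)] (hypothetical helper)."""
    return bool((image >= 0).all() and (image <= max_value(image.dtype)).all())
```

For example, `max_value(torch.uint8)` is `255`, `max_value(torch.int16)` is `32767`, and `max_value(torch.float32)` is `1.0`.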
However, we currently have two kernels that only support uint8 images or videos:
vision/torchvision/prototype/transforms/functional/_color.py
Lines 373 to 375 in c84dbfa

```python
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
    if image.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
```

vision/torchvision/transforms/functional_tensor.py
Lines 788 to 789 in c84dbfa

```python
if img.dtype != torch.uint8:
    raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
```
This also holds for transforms v1, so this is not a problem specific to the new API.
One consequence is that the auto-augment (AA) transforms only support uint8 images:

vision/torchvision/transforms/autoaugment.py
Lines 104 to 107 in c84dbfa

```python
class AutoAugment(torch.nn.Module):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
```
This is because they rely on both the posterize and the equalize kernels:

vision/torchvision/transforms/autoaugment.py
Lines 76 to 77 in c84dbfa

```python
elif op_name == "Posterize":
    img = F.posterize(img, int(magnitude))
```

vision/torchvision/transforms/autoaugment.py
Lines 82 to 83 in c84dbfa

```python
elif op_name == "Equalize":
    img = F.equalize(img)
```
One possible mitigation is to simply call `convert_dtype(image, torch.uint8)` at the beginning and convert back after the computation.
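A minimal sketch of that mitigation, assuming a kernel that only accepts uint8 tensors. The wrapper name `apply_via_uint8` is hypothetical, and the float/uint8 conversion is written out by hand (scale by 255 and round) rather than calling `convert_dtype`, so its rounding may differ from torchvision's own conversion:

```python
from typing import Callable

import torch


def apply_via_uint8(kernel: Callable[[torch.Tensor], torch.Tensor], image: torch.Tensor) -> torch.Tensor:
    """Run a uint8-only kernel on a uint8 or floating point image (hypothetical wrapper)."""
    if image.dtype == torch.uint8:
        return kernel(image)
    if not image.is_floating_point():
        # Other integer dtypes are left out of this sketch.
        raise TypeError(f"Unsupported dtype {image.dtype}")
    # float in [0.0, 1.0] -> uint8 in [0, 255]
    image_uint8 = image.mul(255).round().clamp(0, 255).to(torch.uint8)
    output = kernel(image_uint8)
    # uint8 -> original floating point dtype, back in [0.0, 1.0]
    return output.to(image.dtype).div(255)
```

The round trip is lossy for float inputs: with an identity kernel, `0.5` comes back as `128 / 255`. That quantization is the main cost of this approach.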
Converting is probably required for equalize, since we recently switched from torch's histogram ops to our "custom" implementation to enable batch processing (#6757). However, that implementation relies on the input being an integer tensor and, in its current form, even on it being uint8, due to some hardcoded constants.
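To illustrate why a batched histogram ties equalize to integer (and currently uint8) inputs, here is a minimal sketch built on `Tensor.scatter_add_`. This is not the code from #6757, and the function name is hypothetical. Pixel values are used directly as bin indices, which only works for integers, and the bin count of 256 is a uint8-specific constant:

```python
import torch


def batched_histogram_uint8(images: torch.Tensor) -> torch.Tensor:
    """256-bin histogram per image plane for uint8 input of shape (..., H, W) (hypothetical helper)."""
    if images.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 is supported, but found {images.dtype}")
    # Flatten the spatial dims to (..., H * W), then fold all leading dims into one batch dim.
    flat = images.flatten(start_dim=-2)
    batch_shape = flat.shape[:-1]
    flat = flat.reshape(-1, flat.shape[-1])
    # 256 bins because uint8 values are integers in [0, 255]; this constant is uint8-specific.
    hist = torch.zeros(flat.shape[0], 256, dtype=torch.int64)
    # Each pixel value is used directly as its bin index, which requires integer input.
    hist.scatter_add_(1, flat.long(), torch.ones_like(flat, dtype=torch.int64))
    return hist.reshape(*batch_shape, 256)
```

A float image in `[0.0, 1.0]` cannot be fed to this directly: its values are not valid bin indices, so it would first have to be converted, which is the mitigation discussed above.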
For posterize, I think it is fairly easy to provide the same functionality for float inputs directly, without a prior dtype conversion.
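A minimal sketch of what this could look like. For uint8, posterize keeps the `bits` most significant bits of each value and zeroes the rest; for floats in `[0.0, 1.0]`, a comparable effect is flooring to `2 ** bits` levels. Both function names are hypothetical, and the float variant approximates the uint8 behavior rather than matching it bit-exactly (with `bits=1`, uint8 `255` maps to `128`, roughly `0.502`, while float `1.0` maps to `0.5`):

```python
import torch


def posterize_uint8(image: torch.Tensor, bits: int) -> torch.Tensor:
    # Keep the `bits` most significant bits of each 8-bit value, zero the rest.
    mask = (0xFF << (8 - bits)) & 0xFF
    return image & mask


def posterize_float(image: torch.Tensor, bits: int) -> torch.Tensor:
    # Quantize [0.0, 1.0] into 2 ** bits levels. The clamp keeps 1.0 in the top
    # level instead of creating an extra level above it.
    levels = 1 << bits
    return image.mul(levels).floor().clamp(0, levels - 1).div(levels)
```

The float path needs no integer arithmetic at all, so no dtype conversion (and no quantization to 256 levels) is involved.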