From a0fbc5f835cb8fc24b78cc3158ac4666a0fe0cf4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 May 2023 10:59:06 +0000 Subject: [PATCH 1/5] Allow classification references to use the tensor backend --- references/classification/presets.py | 47 +++++++++++++++++++++------- references/classification/train.py | 9 ++++-- torchvision/transforms/_presets.py | 8 ++++- 3 files changed, 50 insertions(+), 14 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 5d1bf1cc714..e4e17a7190d 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -16,8 +16,16 @@ def __init__( ra_magnitude=9, augmix_severity=3, random_erase_prob=0.0, + backend="pil", ): - trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + trans = [] + backend = backend.lower() + if backend == "tensor": + trans.append(transforms.PILToTensor()) + else: + assert backend == "pil" + + trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True)) if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) if auto_augment_policy is not None: @@ -30,9 +38,13 @@ def __init__( else: aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) + + if backend == "pil": + # Note: we could also just use pure tensors? 
+ trans.append(transforms.PILToTensor()) + trans.extend( [ - transforms.PILToTensor(), transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] @@ -55,17 +67,30 @@ def __init__( mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), interpolation=InterpolationMode.BILINEAR, + backend="pil", ): + trans = [] - self.transforms = transforms.Compose( - [ - transforms.Resize(resize_size, interpolation=interpolation), - transforms.CenterCrop(crop_size), - transforms.PILToTensor(), - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, std=std), - ] - ) + backend = backend.lower() + if backend == "tensor": + trans.append(transforms.PILToTensor()) + else: + assert backend == "pil" + + trans += [ + transforms.Resize(resize_size, interpolation=interpolation, antialias=True), + transforms.CenterCrop(crop_size), + ] + + if backend == "pil": + trans.append(transforms.PILToTensor()) + + trans += [ + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + + self.transforms = transforms.Compose(trans) def __call__(self, img): return self.transforms(img) diff --git a/references/classification/train.py b/references/classification/train.py index 10ba22bce03..7b8072dbce8 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -143,6 +143,7 @@ def load_data(traindir, valdir, args): random_erase_prob=random_erase_prob, ra_magnitude=ra_magnitude, augmix_severity=augmix_severity, + backend=args.backend, ), ) if args.cache_dataset: @@ -160,10 +161,13 @@ def load_data(traindir, valdir, args): else: if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) - preprocessing = weights.transforms() + preprocessing = weights.transforms(antialias=True, backend=args.backend) else: preprocessing = presets.ClassificationPresetEval( - crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation + crop_size=val_crop_size, + 
resize_size=val_resize_size, + interpolation=interpolation, + backend=args.backend, ) dataset_test = torchvision.datasets.ImageFolder( @@ -507,6 +511,7 @@ def get_args_parser(add_help=True): "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--backend", default="PIL", type=str, help="PIL or tensor - case insensitive") return parser diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index ccbe425f2ac..0143ccab05f 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -45,6 +45,7 @@ def __init__( std: Tuple[float, ...] = (0.229, 0.224, 0.225), interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: Optional[Union[str, bool]] = "warn", + backend="pil", ) -> None: super().__init__() self.crop_size = [crop_size] @@ -53,11 +54,16 @@ def __init__( self.std = list(std) self.interpolation = interpolation self.antialias = antialias + self.backend = backend.lower() + if self.backend not in ("pil", "tensor"): + raise ValueError(f"backend parameter must be pil or tensor, got {backend}") def forward(self, img: Tensor) -> Tensor: + if self.backend == "tensor": + img = F.pil_to_tensor(img) img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias) img = F.center_crop(img, self.crop_size) - if not isinstance(img, Tensor): + if self.backend == "pil": img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) img = F.normalize(img, mean=self.mean, std=self.std) From 79c4cda4220d8eab74f37e052c2ce236781af6f4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 May 2023 13:02:06 +0100 Subject: [PATCH 2/5] Update references/classification/presets.py Co-authored-by: Philip Meier --- references/classification/presets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) 
diff --git a/references/classification/presets.py b/references/classification/presets.py index e4e17a7190d..8f804dff884 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -22,8 +22,8 @@ def __init__( backend = backend.lower() if backend == "tensor": trans.append(transforms.PILToTensor()) - else: - assert backend == "pil" + elif backend != "pil": + raise ValueError("backend can be 'tensor' or 'pil', but got {backend}") trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True)) if hflip_prob > 0: From 4746ecec729e2b1fa5a1d59e736886049417e5d8 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 May 2023 12:02:26 +0000 Subject: [PATCH 3/5] remove comment --- references/classification/presets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 8f804dff884..2fc5796a062 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -40,7 +40,6 @@ def __init__( trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) if backend == "pil": - # Note: we could also just use pure tensors? 
trans.append(transforms.PILToTensor()) trans.extend( From de3c4d4c2d992d52de4c426383980c89d5347efb Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 May 2023 12:12:16 +0000 Subject: [PATCH 4/5] Address comments --- references/classification/train.py | 6 +++++- torchvision/transforms/_presets.py | 8 +------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 7b8072dbce8..6281a19719f 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -7,6 +7,7 @@ import torch import torch.utils.data import torchvision +import torchvision.transforms import transforms import utils from sampler import RASampler @@ -161,7 +162,10 @@ def load_data(traindir, valdir, args): else: if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) - preprocessing = weights.transforms(antialias=True, backend=args.backend) + preprocessing = weights.transforms(antialias=True) + if args.backend.lower() == "tensor": + preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing]) + else: preprocessing = presets.ClassificationPresetEval( crop_size=val_crop_size, diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 0143ccab05f..ccbe425f2ac 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -45,7 +45,6 @@ def __init__( std: Tuple[float, ...] 
= (0.229, 0.224, 0.225),
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         antialias: Optional[Union[str, bool]] = "warn",
-        backend="pil",
     ) -> None:
         super().__init__()
         self.crop_size = [crop_size]
@@ -54,16 +53,11 @@
         self.std = list(std)
         self.interpolation = interpolation
         self.antialias = antialias
-        self.backend = backend.lower()
-        if self.backend not in ("pil", "tensor"):
-            raise ValueError(f"backend parameter must be pil or tensor, got {backend}")
 
     def forward(self, img: Tensor) -> Tensor:
-        if self.backend == "tensor":
-            img = F.pil_to_tensor(img)
         img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
         img = F.center_crop(img, self.crop_size)
-        if self.backend == "pil":
+        if not isinstance(img, Tensor):
             img = F.pil_to_tensor(img)
         img = F.convert_image_dtype(img, torch.float)
         img = F.normalize(img, mean=self.mean, std=self.std)

From a618a5f931ccf0fb76fbfc55eedcae9fde31a050 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Fri, 26 May 2023 11:05:16 +0000
Subject: [PATCH 5/5] Address comments

---
 references/classification/presets.py | 4 ++--
 references/classification/train.py   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/references/classification/presets.py b/references/classification/presets.py
index 2fc5796a062..a710f92ae88 100644
--- a/references/classification/presets.py
+++ b/references/classification/presets.py
@@ -23,7 +23,7 @@ def __init__(
         if backend == "tensor":
             trans.append(transforms.PILToTensor())
         elif backend != "pil":
-            raise ValueError("backend can be 'tensor' or 'pil', but got {backend}")
+            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
         trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
         if hflip_prob > 0:
@@ -74,7 +74,7 @@ def __init__(
         if backend == "tensor":
             trans.append(transforms.PILToTensor())
-        else:
-            assert backend == "pil"
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
         trans += [
             transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
diff --git a/references/classification/train.py b/references/classification/train.py
index 6281a19719f..0c1a301453d 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -163,7 +163,7 @@ def load_data(traindir, valdir, args):
         if args.weights and args.test_only:
             weights = torchvision.models.get_weight(args.weights)
             preprocessing = weights.transforms(antialias=True)
-            if args.backend.lower() == "tensor":
+            if args.backend == "tensor":
                 preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
 
         else:
@@ -515,7 +515,7 @@ def get_args_parser(add_help=True):
         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
     )
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
-    parser.add_argument("--backend", default="PIL", type=str, help="PIL or tensor - case insensitive")
+    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
     return parser