diff --git a/references/classification/presets.py b/references/classification/presets.py
index 5d1bf1cc714..a710f92ae88 100644
--- a/references/classification/presets.py
+++ b/references/classification/presets.py
@@ -16,8 +16,16 @@ def __init__(
         ra_magnitude=9,
         augmix_severity=3,
         random_erase_prob=0.0,
+        backend="pil",
     ):
-        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+        trans = []
+        backend = backend.lower()
+        if backend == "tensor":
+            trans.append(transforms.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
+
+        trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
         if hflip_prob > 0:
             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
@@ -30,9 +38,12 @@ def __init__(
             else:
                 aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
                 trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
+
+        if backend == "pil":
+            trans.append(transforms.PILToTensor())
+
         trans.extend(
             [
-                transforms.PILToTensor(),
                 transforms.ConvertImageDtype(torch.float),
                 transforms.Normalize(mean=mean, std=std),
             ]
@@ -55,17 +66,30 @@ def __init__(
         mean=(0.485, 0.456, 0.406),
         std=(0.229, 0.224, 0.225),
         interpolation=InterpolationMode.BILINEAR,
+        backend="pil",
     ):
-        self.transforms = transforms.Compose(
-            [
-                transforms.Resize(resize_size, interpolation=interpolation),
-                transforms.CenterCrop(crop_size),
-                transforms.PILToTensor(),
-                transforms.ConvertImageDtype(torch.float),
-                transforms.Normalize(mean=mean, std=std),
-            ]
-        )
+        trans = []
+        backend = backend.lower()
+        if backend == "tensor":
+            trans.append(transforms.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
+
+        trans += [
+            transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
+            transforms.CenterCrop(crop_size),
+        ]
+
+        if backend == "pil":
+            trans.append(transforms.PILToTensor())
+
+        trans += [
+            transforms.ConvertImageDtype(torch.float),
+            transforms.Normalize(mean=mean, std=std),
+        ]
+
+        self.transforms = transforms.Compose(trans)
 
     def __call__(self, img):
         return self.transforms(img)
 
diff --git a/references/classification/train.py b/references/classification/train.py
index 10ba22bce03..0c1a301453d 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -7,6 +7,7 @@
 import torch
 import torch.utils.data
 import torchvision
+import torchvision.transforms
 import transforms
 import utils
 from sampler import RASampler
@@ -143,6 +144,7 @@ def load_data(traindir, valdir, args):
             random_erase_prob=random_erase_prob,
             ra_magnitude=ra_magnitude,
             augmix_severity=augmix_severity,
+            backend=args.backend,
         ),
     )
     if args.cache_dataset:
@@ -160,10 +162,16 @@ def load_data(traindir, valdir, args):
     else:
         if args.weights and args.test_only:
             weights = torchvision.models.get_weight(args.weights)
-            preprocessing = weights.transforms()
+            preprocessing = weights.transforms(antialias=True)
+            if args.backend == "tensor":
+                preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
+
         else:
             preprocessing = presets.ClassificationPresetEval(
-                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
+                crop_size=val_crop_size,
+                resize_size=val_resize_size,
+                interpolation=interpolation,
+                backend=args.backend,
             )
 
     dataset_test = torchvision.datasets.ImageFolder(
@@ -507,6 +515,7 @@ def get_args_parser(add_help=True):
         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
 
     return parser
 
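A minimal usage sketch (an illustration, not part of the patch; it assumes you run from references/classification so the local presets module is importable) of what the new backend option does: with the default backend="pil" the augmentations operate on PIL images and PILToTensor() is appended after them, while backend="tensor" prepends PILToTensor() so every subsequent transform runs on uint8 tensors. The input in both cases is still a PIL image:

    import PIL.Image
    from presets import ClassificationPresetTrain, ClassificationPresetEval

    img = PIL.Image.new("RGB", (256, 256))

    # Default PIL backend: augmentations see PIL images, tensor conversion happens last.
    train_pil = ClassificationPresetTrain(crop_size=176)

    # Tensor backend: PILToTensor() runs first, so RandomResizedCrop and friends see tensors.
    train_tensor = ClassificationPresetTrain(crop_size=176, backend="tensor")

    out = train_tensor(img)  # float32 tensor of shape (3, 176, 176)

    # The eval preset takes the same flag; values are lower-cased internally,
    # mirroring the case-insensitive --backend CLI flag (type=str.lower).
    eval_tensor = ClassificationPresetEval(crop_size=224, backend="TENSOR")

From the command line the equivalent is `python train.py --backend tensor` (any casing works, since argparse lower-cases the value before it reaches load_data).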