11import torch
2- from torchvision .transforms import autoaugment , transforms
32from torchvision .transforms .functional import InterpolationMode
43
54
5+ def get_module (use_v2 ):
6+ # We need a protected import to avoid the V2 warning in case just V1 is used
7+ if use_v2 :
8+ import torchvision .transforms .v2
9+
10+ return torchvision .transforms .v2
11+ else :
12+ import torchvision .transforms
13+
14+ return torchvision .transforms
15+
16+
617class ClassificationPresetTrain :
718 def __init__ (
819 self ,
@@ -17,41 +28,44 @@ def __init__(
1728 augmix_severity = 3 ,
1829 random_erase_prob = 0.0 ,
1930 backend = "pil" ,
31+ use_v2 = False ,
2032 ):
21- trans = []
33+ module = get_module (use_v2 )
34+
35+ transforms = []
2236 backend = backend .lower ()
2337 if backend == "tensor" :
24- trans .append (transforms .PILToTensor ())
38+ transforms .append (module .PILToTensor ())
2539 elif backend != "pil" :
2640 raise ValueError (f"backend can be 'tensor' or 'pil', but got { backend } " )
2741
28- trans .append (transforms .RandomResizedCrop (crop_size , interpolation = interpolation , antialias = True ))
42+ transforms .append (module .RandomResizedCrop (crop_size , interpolation = interpolation , antialias = True ))
2943 if hflip_prob > 0 :
30- trans .append (transforms .RandomHorizontalFlip (hflip_prob ))
44+ transforms .append (module .RandomHorizontalFlip (hflip_prob ))
3145 if auto_augment_policy is not None :
3246 if auto_augment_policy == "ra" :
33- trans .append (autoaugment .RandAugment (interpolation = interpolation , magnitude = ra_magnitude ))
47+ transforms .append (module .RandAugment (interpolation = interpolation , magnitude = ra_magnitude ))
3448 elif auto_augment_policy == "ta_wide" :
35- trans .append (autoaugment .TrivialAugmentWide (interpolation = interpolation ))
49+ transforms .append (module .TrivialAugmentWide (interpolation = interpolation ))
3650 elif auto_augment_policy == "augmix" :
37- trans .append (autoaugment .AugMix (interpolation = interpolation , severity = augmix_severity ))
51+ transforms .append (module .AugMix (interpolation = interpolation , severity = augmix_severity ))
3852 else :
39- aa_policy = autoaugment .AutoAugmentPolicy (auto_augment_policy )
40- trans .append (autoaugment .AutoAugment (policy = aa_policy , interpolation = interpolation ))
53+ aa_policy = module .AutoAugmentPolicy (auto_augment_policy )
54+ transforms .append (module .AutoAugment (policy = aa_policy , interpolation = interpolation ))
4155
4256 if backend == "pil" :
43- trans .append (transforms .PILToTensor ())
57+ transforms .append (module .PILToTensor ())
4458
45- trans .extend (
59+ transforms .extend (
4660 [
47- transforms .ConvertImageDtype (torch .float ),
48- transforms .Normalize (mean = mean , std = std ),
61+ module .ConvertImageDtype (torch .float ),
62+ module .Normalize (mean = mean , std = std ),
4963 ]
5064 )
5165 if random_erase_prob > 0 :
52- trans .append (transforms .RandomErasing (p = random_erase_prob ))
66+ transforms .append (module .RandomErasing (p = random_erase_prob ))
5367
54- self .transforms = transforms .Compose (trans )
68+ self .transforms = module .Compose (transforms )
5569
5670 def __call__ (self , img ):
5771 return self .transforms (img )
@@ -67,28 +81,30 @@ def __init__(
6781 std = (0.229 , 0.224 , 0.225 ),
6882 interpolation = InterpolationMode .BILINEAR ,
6983 backend = "pil" ,
84+ use_v2 = False ,
7085 ):
71- trans = []
86+ module = get_module (use_v2 )
87+ transforms = []
7288 backend = backend .lower ()
7389 if backend == "tensor" :
74- trans .append (transforms .PILToTensor ())
90+ transforms .append (module .PILToTensor ())
7591 elif backend != "pil" :
7692 raise ValueError (f"backend can be 'tensor' or 'pil', but got { backend } " )
7793
78- trans += [
79- transforms .Resize (resize_size , interpolation = interpolation , antialias = True ),
80- transforms .CenterCrop (crop_size ),
94+ transforms += [
95+ module .Resize (resize_size , interpolation = interpolation , antialias = True ),
96+ module .CenterCrop (crop_size ),
8197 ]
8298
8399 if backend == "pil" :
84- trans .append (transforms .PILToTensor ())
100+ transforms .append (module .PILToTensor ())
85101
86- trans += [
87- transforms .ConvertImageDtype (torch .float ),
88- transforms .Normalize (mean = mean , std = std ),
102+ transforms += [
103+ module .ConvertImageDtype (torch .float ),
104+ module .Normalize (mean = mean , std = std ),
89105 ]
90106
91- self .transforms = transforms .Compose (trans )
107+ self .transforms = module .Compose (transforms )
92108
93109 def __call__ (self , img ):
94110 return self .transforms (img )
0 commit comments