@@ -84,30 +84,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
8484 return output
8585
8686
87- class _RandomChannelShuffle (Transform ):
88- def _get_params (self , sample : Any ) -> Dict [str , Any ]:
89- image = query_image (sample )
90- num_channels , _ , _ = get_image_dimensions (image )
91- return dict (permutation = torch .randperm (num_channels ))
92-
93- def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
94- if not (isinstance (inpt , (features .Image , PIL .Image .Image )) or is_simple_tensor (inpt )):
95- return inpt
96-
97- image = inpt
98- if isinstance (inpt , PIL .Image .Image ):
99- image = _F .pil_to_tensor (image )
100-
101- output = image [..., params ["permutation" ], :, :]
102-
103- if isinstance (inpt , features .Image ):
104- output = features .Image .new_like (inpt , output , color_space = features .ColorSpace .OTHER )
105- elif isinstance (inpt , PIL .Image .Image ):
106- output = _F .to_pil_image (output )
107-
108- return output
109-
110-
11187class RandomPhotometricDistort (Transform ):
11288 def __init__ (
11389 self ,
@@ -118,35 +94,62 @@ def __init__(
11894 p : float = 0.5 ,
11995 ):
12096 super ().__init__ ()
121- self ._brightness = ColorJitter (brightness = brightness )
122- self ._contrast = ColorJitter (contrast = contrast )
123- self ._hue = ColorJitter (hue = hue )
124- self ._saturation = ColorJitter (saturation = saturation )
125- self ._channel_shuffle = _RandomChannelShuffle ()
97+ self .brightness = brightness
98+ self .contrast = contrast
99+ self .hue = hue
100+ self .saturation = saturation
126101 self .p = p
127102
128103 def _get_params (self , sample : Any ) -> Dict [str , Any ]:
104+ image = query_image (sample )
105+ num_channels , _ , _ = get_image_dimensions (image )
129106 return dict (
130107 zip (
131- ["brightness" , "contrast1" , "saturation" , "hue" , "contrast2" , "channel_shuffle" ],
108+ ["brightness" , "contrast1" , "saturation" , "hue" , "contrast2" ],
132109 torch .rand (6 ) < self .p ,
133110 ),
134111 contrast_before = torch .rand (()) < 0.5 ,
112+ channel_permutation = torch .randperm (num_channels ) if torch .rand (()) < self .p else None ,
135113 )
136114
115+ def _permute_channels (self , inpt : Any , * , permutation : torch .Tensor ) -> Any :
116+ if not (isinstance (inpt , (features .Image , PIL .Image .Image )) or is_simple_tensor (inpt )):
117+ return inpt
118+
119+ image = inpt
120+ if isinstance (inpt , PIL .Image .Image ):
121+ image = _F .pil_to_tensor (image )
122+
123+ output = image [..., permutation , :, :]
124+
125+ if isinstance (inpt , features .Image ):
126+ output = features .Image .new_like (inpt , output , color_space = features .ColorSpace .OTHER )
127+ elif isinstance (inpt , PIL .Image .Image ):
128+ output = _F .to_pil_image (output )
129+
130+ return output
131+
137132 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
138133 if params ["brightness" ]:
139- inpt = self ._brightness (inpt )
134+ inpt = F .adjust_brightness (
135+ inpt , brightness_factor = ColorJitter ._generate_value (self .brightness [0 ], self .brightness [1 ])
136+ )
140137 if params ["contrast1" ] and params ["contrast_before" ]:
141- inpt = self . _contrast ( inpt )
142- if params [ "saturation" ]:
143- inpt = self . _saturation ( inpt )
138+ inpt = F . adjust_contrast (
139+ inpt , contrast_factor = ColorJitter . _generate_value ( self . contrast [ 0 ], self . contrast [ 1 ])
140+ )
144141 if params ["saturation" ]:
145- inpt = self ._saturation (inpt )
142+ inpt = F .adjust_saturation (
143+ inpt , saturation_factor = ColorJitter ._generate_value (self .saturation [0 ], self .saturation [1 ])
144+ )
145+ if params ["hue" ]:
146+ inpt = F .adjust_hue (inpt , hue_factor = ColorJitter ._generate_value (self .hue [0 ], self .hue [1 ]))
146147 if params ["contrast2" ] and not params ["contrast_before" ]:
147- inpt = self ._contrast (inpt )
148- if params ["channel_shuffle" ]:
149- inpt = self ._channel_shuffle (inpt )
148+ inpt = F .adjust_contrast (
149+ inpt , contrast_factor = ColorJitter ._generate_value (self .contrast [0 ], self .contrast [1 ])
150+ )
151+ if params ["channel_permutation" ]:
152+ inpt = self ._permute_channels (inpt , permutation = params ["channel_permutation" ])
150153 return inpt
151154
152155
0 commit comments