@@ -285,3 +285,99 @@ def convert_color_space(
        return features.Video.wrap_like(inpt, output, color_space=color_space)
    else:
        return convert_color_space_image_pil(inpt, color_space)


def _num_value_bits(dtype: torch.dtype) -> int:
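    # The number of "value bits" is the number of bits that carry the actual value, i.e. the total number of bits
    # minus the sign bit for signed dtypes. For example, `torch.int8` has 8 bits, of which 7 encode the value.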
    if dtype == torch.uint8:
        return 8
    elif dtype == torch.int8:
        return 7
    elif dtype == torch.int16:
        return 15
    elif dtype == torch.int32:
        return 31
    elif dtype == torch.int64:
        return 63
    else:
        raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")


def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    if image.dtype == dtype:
        return image

    float_input = image.is_floating_point()
    if torch.jit.is_scripting():
        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
    else:
        float_output = dtype.is_floating_point

    if float_input:
        # float to float
        if float_output:
            return image.to(dtype)

        # float to int
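        # `float32` cannot store consecutive integers over the whole `int32` / `int64` range (and neither can
        # `float64` for `int64`), so scaling to those dtypes could silently overflow or lose precision.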
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
        # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # for a detailed analysis.
        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
        # Instead, we can also multiply by the maximum value plus something close to `1`. See
        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
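        # For example, with `dtype=torch.uint8` the scale factor below is `255 + 1 - 1e-3 = 255.999`: an input of
        # exactly `1.0` maps to `int(255.999) = 255` without overflowing, and `0.999` maps to `int(255.743) = 255`
        # instead of being truncated to `254`.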
        eps = 1e-3
        max_value = float(_FT._max_value(dtype))
        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
        # discrete set `{0, 1}`.
        return image.mul(max_value + 1.0 - eps).to(dtype)
    else:
        # int to float
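        # Dividing by the maximum value of the input dtype maps the integer range onto `[0.0, 1.0]`, e.g.
        # `torch.uint8` values `0` to `255` are divided by `255`.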
        if float_output:
            return image.to(dtype).div_(_FT._max_value(image.dtype))

        # int to int
        num_value_bits_input = _num_value_bits(image.dtype)
        num_value_bits_output = _num_value_bits(dtype)

        if num_value_bits_input > num_value_bits_output:
            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
        else:
            # The bitshift kernel is not vectorized
            # https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
            # This results in the multiplication actually being faster.
            # TODO: If the bitshift kernel is optimized in core, replace the computation below with
            # `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
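            # For example, for a `torch.uint8` -> `torch.int32` conversion the factor is
            # `(2147483647 + 1) // (255 + 1) = 2 ** (31 - 8) = 8388608`, i.e. multiplying is equivalent to shifting
            # left by the difference in value bits.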
            max_value_input = float(_FT._max_value(image.dtype))
            max_value_output = float(_FT._max_value(dtype))
            factor = int((max_value_output + 1) // (max_value_input + 1))
            return image.to(dtype).mul_(factor)


# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
convert_image_dtype = convert_dtype_image_tensor


def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    return convert_dtype_image_tensor(video, dtype)


def convert_dtype(
    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], dtype: torch.dtype = torch.float
) -> torch.Tensor:
    if isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
    ):
        return convert_dtype_image_tensor(inpt, dtype)
    elif isinstance(inpt, features.Image):
        output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
        return features.Image.wrap_like(inpt, output)
    else:  # isinstance(inpt, features.Video):
        output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
        return features.Video.wrap_like(inpt, output)