@@ -503,10 +503,45 @@ def rotate(
503503 return inpt
504504
505505
506- pad_image_tensor = _FT .pad
# Padding kernel for PIL images: a direct alias of `_FP.pad`, so its signature and
# behaviour are whatever `_FP.pad` defines (presumably the PIL functional backend — confirm
# against the module's imports).
pad_image_pil = _FP.pad
508507
509508
def pad_image_tensor(
    img: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant"
) -> torch.Tensor:
    """Pad an image tensor that may carry arbitrary leading (batch) dimensions.

    All dimensions in front of the last three are collapsed into a single
    leading dimension before delegating to ``_FT.pad`` and are restored on the
    result, so the output has the same leading shape as ``img``.

    Args:
        img: Tensor with at least three dimensions; the last three are treated
            as ``(channels, height, width)``.
        padding: Padding specification, forwarded unchanged to ``_FT.pad``.
        fill: Fill value, forwarded unchanged to ``_FT.pad``.
        padding_mode: Padding mode, forwarded unchanged to ``_FT.pad``.

    Returns:
        The padded tensor of shape ``(*leading_dims, channels, new_height, new_width)``.
    """
    num_channels, height, width = img.shape[-3:]
    extra_dims = img.shape[:-3]

    # `reshape` instead of `view`: `view` raises on non-contiguous inputs (e.g. a
    # permuted HWC -> CHW tensor), whereas `reshape` is a no-copy view whenever
    # possible and only copies when it has to.
    padded_image = _FT.pad(
        img=img.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
    )

    # Only the spatial size can have changed; restore the original leading dims.
    new_height, new_width = padded_image.shape[-2:]
    return padded_image.reshape(extra_dims + (num_channels, new_height, new_width))
521+
522+
# TODO: Remove this helper once PyTorch's pad supports non-scalar fill values.
def _pad_with_vector_fill(
    img: torch.Tensor, padding: List[int], fill: Union[float, List[float]] = 0.0, padding_mode: str = "constant"
) -> torch.Tensor:
    """Pad ``img`` with a per-channel fill value.

    The image is first padded with zeros through ``pad_image_tensor`` and the
    newly added border is then overwritten with ``fill``, broadcast along the
    channel dimension.

    Args:
        img: Image tensor accepted by ``pad_image_tensor``.
        padding: Either a single value (applied to all sides), two values
            ``[left/right, top/bottom]``, or four values
            ``[left, top, right, bottom]``.
        fill: Scalar or one value per channel. It is converted to a tensor of
            ``img``'s dtype and device and reshaped to ``(-1, 1, 1)``, so it must
            be broadcastable against the channel dimension.
        padding_mode: Must be ``"constant"``.

    Returns:
        The padded tensor with the border set to ``fill``.

    Raises:
        ValueError: If ``padding_mode`` is not ``"constant"`` or ``padding``
            does not have 1, 2 or 4 elements.
    """
    if padding_mode != "constant":
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

    output = pad_image_tensor(img, padding, fill=0, padding_mode="constant")

    # Normalise the short padding forms to explicit per-side widths. Unpacking
    # `padding` directly only works for the 4-element form and would raise for
    # the 1- and 2-element forms that the scalar-fill path accepts.
    if isinstance(padding, int):
        left = top = right = bottom = padding
    elif len(padding) == 1:
        left = top = right = bottom = padding[0]
    elif len(padding) == 2:
        left = right = padding[0]
        top = bottom = padding[1]
    elif len(padding) == 4:
        left, top, right, bottom = padding
    else:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element sequence, got {padding}")

    # Shape (C, 1, 1) so the assignment broadcasts over the spatial dimensions.
    fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)

    # Non-positive widths add no border on that side, so there is nothing to overwrite.
    if top > 0:
        output[..., :top, :] = fill
    if left > 0:
        output[..., :, :left] = fill
    if bottom > 0:
        output[..., -bottom:, :] = fill
    if right > 0:
        output[..., :, -right:] = fill
    return output
543+
544+
510545def pad_segmentation_mask (
511546 segmentation_mask : torch .Tensor , padding : List [int ], padding_mode : str = "constant"
512547) -> torch .Tensor :
@@ -537,13 +572,19 @@ def pad_bounding_box(
537572 return bounding_box
538573
539574
def pad(
    inpt: Any, padding: List[int], fill: Union[float, Sequence[float]] = 0.0, padding_mode: str = "constant"
) -> Any:
    """Dispatch padding to the implementation that matches the type of ``inpt``.

    * ``features._Feature`` instances pad themselves through their own ``pad`` method.
    * PIL images are handled by ``pad_image_pil``.
    * Plain tensors are handled by ``pad_image_tensor`` when ``fill`` is a scalar
      and by ``_pad_with_vector_fill`` otherwise.
    * Any other input is returned unchanged.
    """
    if isinstance(inpt, features._Feature):
        return inpt.pad(padding, fill=fill, padding_mode=padding_mode)

    if isinstance(inpt, PIL.Image.Image):
        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)

    if not isinstance(inpt, torch.Tensor):
        # Unsupported types pass through untouched.
        return inpt

    # PyTorch's pad only accepts a scalar fill, so a non-scalar fill has to be
    # painted over the padded border afterwards.
    tensor_kernel = pad_image_tensor if isinstance(fill, (int, float)) else _pad_with_vector_fill
    return tensor_kernel(inpt, padding, fill=fill, padding_mode=padding_mode)
549590
0 commit comments