@@ -507,7 +507,7 @@ def rotate(
507507
508508
509509def pad_image_tensor (
510- img : torch .Tensor , padding : List [int ], fill : int = 0 , padding_mode : str = "constant"
510+ img : torch .Tensor , padding : List [int ], fill : Union [ int , float ] = 0 , padding_mode : str = "constant"
511511) -> torch .Tensor :
512512 num_masks , height , width = img .shape [- 3 :]
513513 extra_dims = img .shape [:- 3 ]
@@ -522,8 +522,11 @@ def pad_image_tensor(
522522
523523# TODO: This should be removed once pytorch pad supports non-scalar padding values
524524def _pad_with_vector_fill (
525- img : torch .Tensor , padding : List [int ], fill : Union [float , List [float ]] = 0.0 , padding_mode : str = "constant"
526- ):
525+ img : torch .Tensor ,
526+ padding : List [int ],
527+ fill : Sequence [float ] = [0.0 ],
528+ padding_mode : str = "constant" ,
529+ ) -> torch .Tensor :
527530 if padding_mode != "constant" :
528531 raise ValueError (f"Padding mode '{ padding_mode } ' is not supported if fill is not scalar" )
529532
@@ -573,7 +576,7 @@ def pad_bounding_box(
573576
574577
575578def pad (
576- inpt : Any , padding : List [int ], fill : Union [float , Sequence [float ]] = 0.0 , padding_mode : str = "constant"
579+ inpt : Any , padding : List [int ], fill : Union [int , float , Sequence [float ]] = 0.0 , padding_mode : str = "constant"
577580) -> Any :
578581 if isinstance (inpt , features ._Feature ):
579582 return inpt .pad (padding , fill = fill , padding_mode = padding_mode )
0 commit comments