@@ -570,6 +570,13 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
570570        # Apply same grid to a batch of images 
571571        grid  =  grid .expand (squashed_batch_size , - 1 , - 1 , - 1 )
572572
573+     if  fill  is  not None  and  not  isinstance (fill , (tuple , list )):
574+         fill  =  [float (fill )]
575+ 
576+     # filling with zeros is the default behavior and thus we can skip the extra fill handling 
577+     if  fill  is  not None  and  all (f  ==  0  for  f  in  fill ):
578+         fill  =  None 
579+ 
573580    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice 
574581    if  fill  is  not None :
575582        mask  =  torch .ones (
@@ -583,8 +590,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
583590    if  fill  is  not None :
584591        float_img , mask  =  torch .tensor_split (float_img , indices = (- 1 ,), dim = - 3 )
585592        mask  =  mask .expand_as (float_img )
586-         fill_list  =  fill  if  isinstance (fill , (tuple , list )) else  [float (fill )]  # type: ignore[arg-type] 
587-         fill_img  =  torch .tensor (fill_list , dtype = float_img .dtype , device = float_img .device ).view (1 , - 1 , 1 , 1 )
593+         fill_img  =  torch .tensor (fill , dtype = float_img .dtype , device = float_img .device ).view (1 , - 1 , 1 , 1 )
588594        if  mode  ==  "nearest" :
589595            bool_mask  =  mask  <  0.5 
590596            float_img [bool_mask ] =  fill_img .expand_as (float_img )[bool_mask ]
0 commit comments