@@ -1569,31 +1569,35 @@ def reference_inputs_equalize_image_tensor():
15691569 # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
15701570 # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
15711571 # the information gain is low if we already provide something really close to the expected value.
1572- def make_uniform_band_image (shape , dtype , device , * , low_factor , high_factor ):
1572+ def make_uniform_band_image (shape , dtype , device , * , low_factor , high_factor , memory_format ):
15731573 if dtype .is_floating_point :
15741574 low = low_factor
15751575 high = high_factor
15761576 else :
15771577 max_value = torch .iinfo (dtype ).max
15781578 low = int (low_factor * max_value )
15791579 high = int (high_factor * max_value )
1580- return torch .testing .make_tensor (shape , dtype = dtype , device = device , low = low , high = high )
1580+ return torch .testing .make_tensor (shape , dtype = dtype , device = device , low = low , high = high ).to (
1581+ memory_format = memory_format , copy = True
1582+ )
15811583
1582- def make_beta_distributed_image (shape , dtype , device , * , alpha , beta ):
1584+ def make_beta_distributed_image (shape , dtype , device , * , alpha , beta , memory_format ):
15831585 image = torch .distributions .Beta (alpha , beta ).sample (shape )
15841586 if not dtype .is_floating_point :
15851587 image .mul_ (torch .iinfo (dtype ).max ).round_ ()
1586- return image .to (dtype = dtype , device = device )
1588+ return image .to (dtype = dtype , device = device , memory_format = memory_format , copy = True )
15871589
15881590 spatial_size = (256 , 256 )
15891591 for dtype , color_space , fn in itertools .product (
15901592 [torch .uint8 ],
15911593 ["GRAY" , "RGB" ],
15921594 [
1593- lambda shape , dtype , device : torch .zeros (shape , dtype = dtype , device = device ),
1594- lambda shape , dtype , device : torch .full (
1595- shape , 1.0 if dtype .is_floating_point else torch .iinfo (dtype ).max , dtype = dtype , device = device
1595+ lambda shape , dtype , device , memory_format : torch .zeros (shape , dtype = dtype , device = device ).to (
1596+ memory_format = memory_format , copy = True
15961597 ),
1598+ lambda shape , dtype , device , memory_format : torch .full (
1599+ shape , 1.0 if dtype .is_floating_point else torch .iinfo (dtype ).max , dtype = dtype , device = device
1600+ ).to (memory_format = memory_format , copy = True ),
15971601 * [
15981602 functools .partial (make_uniform_band_image , low_factor = low_factor , high_factor = high_factor )
15991603 for low_factor , high_factor in [
0 commit comments