@@ -465,11 +465,15 @@ def load(self, device):
class ImageLoader(TensorLoader):
    """Tensor loader for images that also records layout metadata.

    ``spatial_size`` and ``num_channels`` are not passed by the caller
    (``init=False``); they are derived from ``self.shape`` in
    ``__post_init__``.
    """

    # NOTE(review): the ``dataclasses.field`` / ``__post_init__`` usage implies a
    # ``@dataclasses.dataclass`` decorator above this class, outside the visible
    # chunk -- confirm it is still present.
    spatial_size: Tuple[int, int] = dataclasses.field(init=False)
    num_channels: int = dataclasses.field(init=False)
    # Memory layout forwarded to ``fn`` whenever the image is materialized.
    memory_format: torch.memory_format = torch.contiguous_format

    def __post_init__(self):
        # The shape is read as (..., channels, height, width).
        self.spatial_size = self.shape[-2:]
        self.num_channels = self.shape[-3]

    def load(self, device):
        """Create the image on ``device``.

        Calls ``self.fn`` with the stored shape and dtype, the requested
        device, and the stored ``memory_format`` as a keyword argument.
        """
        return self.fn(self.shape, self.dtype, device, memory_format=self.memory_format)
476+
473477
474478NUM_CHANNELS_MAP = {
475479 "GRAY" : 1 ,
@@ -493,18 +497,21 @@ def make_image_loader(
493497 extra_dims = (),
494498 dtype = torch .float32 ,
495499 constant_alpha = True ,
500+ memory_format = torch .contiguous_format ,
496501):
497502 size = _parse_spatial_size (size )
498503 num_channels = get_num_channels (color_space )
499504
500- def fn (shape , dtype , device ):
505+ def fn (shape , dtype , device , memory_format ):
501506 max_value = get_max_value (dtype )
502- data = torch .testing .make_tensor (shape , low = 0 , high = max_value , dtype = dtype , device = device )
507+ data = torch .testing .make_tensor (
508+ shape , low = 0 , high = max_value , dtype = dtype , device = device , memory_format = memory_format
509+ )
503510 if color_space in {"GRAY_ALPHA" , "RGBA" } and constant_alpha :
504511 data [..., - 1 , :, :] = max_value
505512 return datapoints .Image (data )
506513
507- return ImageLoader (fn , shape = (* extra_dims , num_channels , * size ), dtype = dtype )
514+ return ImageLoader (fn , shape = (* extra_dims , num_channels , * size ), dtype = dtype , memory_format = memory_format )
508515
509516
510517make_image = from_loader (make_image_loader )
@@ -530,11 +537,13 @@ def make_image_loaders(
530537make_images = from_loaders (make_image_loaders )
531538
532539
533- def make_image_loader_for_interpolation (size = "random" , * , color_space = "RGB" , dtype = torch .uint8 ):
540+ def make_image_loader_for_interpolation (
541+ size = "random" , * , color_space = "RGB" , dtype = torch .uint8 , memory_format = torch .contiguous_format
542+ ):
534543 size = _parse_spatial_size (size )
535544 num_channels = get_num_channels (color_space )
536545
537- def fn (shape , dtype , device ):
546+ def fn (shape , dtype , device , memory_format ):
538547 height , width = shape [- 2 :]
539548
540549 image_pil = (
@@ -550,19 +559,25 @@ def fn(shape, dtype, device):
550559 )
551560 )
552561
553- image_tensor = convert_dtype_image_tensor (to_image_tensor (image_pil ).to (device = device ), dtype = dtype )
562+ image_tensor = to_image_tensor (image_pil )
563+ if memory_format == torch .contiguous_format :
564+ image_tensor = image_tensor .to (device = device , memory_format = memory_format , copy = True )
565+ else :
566+ image_tensor = image_tensor .to (device = device )
567+ image_tensor = convert_dtype_image_tensor (image_tensor , dtype = dtype )
554568
555569 return datapoints .Image (image_tensor )
556570
557- return ImageLoader (fn , shape = (num_channels , * size ), dtype = dtype )
571+ return ImageLoader (fn , shape = (num_channels , * size ), dtype = dtype , memory_format = memory_format )
558572
559573
def make_image_loaders_for_interpolation(
    sizes=((233, 147),),
    color_spaces=("RGB",),
    dtypes=(torch.uint8,),
    memory_formats=(torch.contiguous_format, torch.channels_last),
):
    """Yield an interpolation image loader for each parameter combination.

    Args:
        sizes: Candidate values for the ``size`` argument.
        color_spaces: Candidate values for the ``color_space`` argument.
        dtypes: Candidate values for the ``dtype`` argument.
        memory_formats: Candidate values for the ``memory_format`` argument.

    Yields:
        The result of ``make_image_loader_for_interpolation(**params)`` for
        every ``params`` mapping produced by ``combinations_grid``
        (presumably the Cartesian product of the candidates -- confirm
        against ``combinations_grid``).
    """
    for params in combinations_grid(
        size=sizes, color_space=color_spaces, dtype=dtypes, memory_format=memory_formats
    ):
        yield make_image_loader_for_interpolation(**params)
567582
568583
@@ -744,8 +759,10 @@ def make_video_loader(
744759 size = _parse_spatial_size (size )
745760 num_frames = int (torch .randint (1 , 5 , ())) if num_frames == "random" else num_frames
746761
747- def fn (shape , dtype , device ):
748- video = make_image (size = shape [- 2 :], extra_dims = shape [:- 3 ], dtype = dtype , device = device )
762+ def fn (shape , dtype , device , memory_format ):
763+ video = make_image (
764+ size = shape [- 2 :], extra_dims = shape [:- 3 ], dtype = dtype , device = device , memory_format = memory_format
765+ )
749766 return datapoints .Video (video )
750767
751768 return VideoLoader (fn , shape = (* extra_dims , num_frames , get_num_channels (color_space ), * size ), dtype = dtype )
0 commit comments