@@ -238,7 +238,6 @@ def load(self, device):
238238
239239@dataclasses .dataclass
240240class ImageLoader (TensorLoader ):
241- color_space : datapoints .ColorSpace
242241 spatial_size : Tuple [int , int ] = dataclasses .field (init = False )
243242 num_channels : int = dataclasses .field (init = False )
244243
@@ -248,10 +247,10 @@ def __post_init__(self):
248247
249248
250249NUM_CHANNELS_MAP = {
251- datapoints . ColorSpace . GRAY : 1 ,
252- datapoints . ColorSpace . GRAY_ALPHA : 2 ,
253- datapoints . ColorSpace . RGB : 3 ,
254- datapoints . ColorSpace . RGB_ALPHA : 4 ,
250+ " GRAY" : 1 ,
251+ " GRAY_ALPHA" : 2 ,
252+ " RGB" : 3 ,
253+ "RGBA" : 4 ,
255254}
256255
257256
@@ -265,7 +264,7 @@ def get_num_channels(color_space):
265264def make_image_loader (
266265 size = "random" ,
267266 * ,
268- color_space = datapoints . ColorSpace . RGB ,
267+ color_space = " RGB" ,
269268 extra_dims = (),
270269 dtype = torch .float32 ,
271270 constant_alpha = True ,
@@ -276,11 +275,11 @@ def make_image_loader(
276275 def fn (shape , dtype , device ):
277276 max_value = get_max_value (dtype )
278277 data = torch .testing .make_tensor (shape , low = 0 , high = max_value , dtype = dtype , device = device )
279- if color_space in {datapoints . ColorSpace . GRAY_ALPHA , datapoints . ColorSpace . RGB_ALPHA } and constant_alpha :
278+ if color_space in {" GRAY_ALPHA" , "RGBA" } and constant_alpha :
280279 data [..., - 1 , :, :] = max_value
281- return datapoints .Image (data , color_space = color_space )
280+ return datapoints .Image (data )
282281
283- return ImageLoader (fn , shape = (* extra_dims , num_channels , * size ), dtype = dtype , color_space = color_space )
282+ return ImageLoader (fn , shape = (* extra_dims , num_channels , * size ), dtype = dtype )
284283
285284
286285make_image = from_loader (make_image_loader )
@@ -290,10 +289,10 @@ def make_image_loaders(
290289 * ,
291290 sizes = DEFAULT_SPATIAL_SIZES ,
292291 color_spaces = (
293- datapoints . ColorSpace . GRAY ,
294- datapoints . ColorSpace . GRAY_ALPHA ,
295- datapoints . ColorSpace . RGB ,
296- datapoints . ColorSpace . RGB_ALPHA ,
292+ " GRAY" ,
293+ " GRAY_ALPHA" ,
294+ " RGB" ,
295+ "RGBA" ,
297296 ),
298297 extra_dims = DEFAULT_EXTRA_DIMS ,
299298 dtypes = (torch .float32 , torch .uint8 ),
@@ -306,7 +305,7 @@ def make_image_loaders(
306305make_images = from_loaders (make_image_loaders )
307306
308307
309- def make_image_loader_for_interpolation (size = "random" , * , color_space = datapoints . ColorSpace . RGB , dtype = torch .uint8 ):
308+ def make_image_loader_for_interpolation (size = "random" , * , color_space = " RGB" , dtype = torch .uint8 ):
310309 size = _parse_spatial_size (size )
311310 num_channels = get_num_channels (color_space )
312311
@@ -318,24 +317,24 @@ def fn(shape, dtype, device):
318317 .resize ((width , height ))
319318 .convert (
320319 {
321- datapoints . ColorSpace . GRAY : "L" ,
322- datapoints . ColorSpace . GRAY_ALPHA : "LA" ,
323- datapoints . ColorSpace . RGB : "RGB" ,
324- datapoints . ColorSpace . RGB_ALPHA : "RGBA" ,
320+ " GRAY" : "L" ,
321+ " GRAY_ALPHA" : "LA" ,
322+ " RGB" : "RGB" ,
323+ "RGBA" : "RGBA" ,
325324 }[color_space ]
326325 )
327326 )
328327
329328 image_tensor = convert_dtype_image_tensor (to_image_tensor (image_pil ).to (device = device ), dtype = dtype )
330329
331- return datapoints .Image (image_tensor , color_space = color_space )
330+ return datapoints .Image (image_tensor )
332331
333- return ImageLoader (fn , shape = (num_channels , * size ), dtype = dtype , color_space = color_space )
332+ return ImageLoader (fn , shape = (num_channels , * size ), dtype = dtype )
334333
335334
336335def make_image_loaders_for_interpolation (
337336 sizes = ((233 , 147 ),),
338- color_spaces = (datapoints . ColorSpace . RGB ,),
337+ color_spaces = (" RGB" ,),
339338 dtypes = (torch .uint8 ,),
340339):
341340 for params in combinations_grid (size = sizes , color_space = color_spaces , dtype = dtypes ):
@@ -583,7 +582,7 @@ class VideoLoader(ImageLoader):
583582def make_video_loader (
584583 size = "random" ,
585584 * ,
586- color_space = datapoints . ColorSpace . RGB ,
585+ color_space = " RGB" ,
587586 num_frames = "random" ,
588587 extra_dims = (),
589588 dtype = torch .uint8 ,
@@ -592,12 +591,10 @@ def make_video_loader(
592591 num_frames = int (torch .randint (1 , 5 , ())) if num_frames == "random" else num_frames
593592
594593 def fn (shape , dtype , device ):
595- video = make_image (size = shape [- 2 :], color_space = color_space , extra_dims = shape [:- 3 ], dtype = dtype , device = device )
596- return datapoints .Video (video , color_space = color_space )
594+ video = make_image (size = shape [- 2 :], extra_dims = shape [:- 3 ], dtype = dtype , device = device )
595+ return datapoints .Video (video )
597596
598- return VideoLoader (
599- fn , shape = (* extra_dims , num_frames , get_num_channels (color_space ), * size ), dtype = dtype , color_space = color_space
600- )
597+ return VideoLoader (fn , shape = (* extra_dims , num_frames , get_num_channels (color_space ), * size ), dtype = dtype )
601598
602599
603600make_video = from_loader (make_video_loader )
@@ -607,8 +604,8 @@ def make_video_loaders(
607604 * ,
608605 sizes = DEFAULT_SPATIAL_SIZES ,
609606 color_spaces = (
610- datapoints . ColorSpace . GRAY ,
611- datapoints . ColorSpace . RGB ,
607+ " GRAY" ,
608+ " RGB" ,
612609 ),
613610 num_frames = (1 , 0 , "random" ),
614611 extra_dims = DEFAULT_EXTRA_DIMS ,
0 commit comments