@@ -250,6 +250,21 @@ def __post_init__(self):
250250 self .num_channels = self .shape [- 3 ]
251251
252252
253+ NUM_CHANNELS_MAP = {
254+ features .ColorSpace .GRAY : 1 ,
255+ features .ColorSpace .GRAY_ALPHA : 2 ,
256+ features .ColorSpace .RGB : 3 ,
257+ features .ColorSpace .RGB_ALPHA : 4 ,
258+ }
259+
260+
261+ def get_num_channels (color_space ):
262+ num_channels = NUM_CHANNELS_MAP .get (color_space )
263+ if not num_channels :
264+ raise pytest .UsageError (f"Can't determine the number of channels for color space { color_space } " )
265+ return num_channels
266+
267+
253268def make_image_loader (
254269 size = "random" ,
255270 * ,
@@ -259,16 +274,7 @@ def make_image_loader(
259274 constant_alpha = True ,
260275):
261276 size = _parse_image_size (size )
262-
263- try :
264- num_channels = {
265- features .ColorSpace .GRAY : 1 ,
266- features .ColorSpace .GRAY_ALPHA : 2 ,
267- features .ColorSpace .RGB : 3 ,
268- features .ColorSpace .RGB_ALPHA : 4 ,
269- }[color_space ]
270- except KeyError as error :
271- raise pytest .UsageError (f"Can't determine the number of channels for color space { color_space } " ) from error
277+ num_channels = get_num_channels (color_space )
272278
273279 def fn (shape , dtype , device ):
274280 max_value = get_max_value (dtype )
@@ -550,13 +556,15 @@ def make_video_loader(
550556 dtype = torch .uint8 ,
551557):
552558 size = _parse_image_size (size )
553- num_frames = int (torch .randint (1 , 4 , ())) if num_frames == "random" else num_frames
559+ num_frames = int (torch .randint (1 , 5 , ())) if num_frames == "random" else num_frames
554560
555561 def fn (shape , dtype , device ):
556562 video = make_image (size = shape [- 2 :], color_space = color_space , extra_dims = shape [:- 2 ], dtype = dtype , device = device )
557563 return features .Video (video , color_space = color_space )
558564
559- return VideoLoader (fn , shape = (* extra_dims , num_frames , * size ), dtype = dtype , color_space = color_space )
565+ return VideoLoader (
566+ fn , shape = (* extra_dims , num_frames , get_num_channels (color_space ), * size ), dtype = dtype , color_space = color_space
567+ )
560568
561569
562570make_video = from_loader (make_video_loader )
0 commit comments