2424
2525
2626class FlowDataset (ABC , VisionDataset ):
27- # Some datasets like Kitti have a built-in valid mask , indicating which flow values are valid
28- # For those we return (img1, img2, flow, valid ), and for the rest we return (img1, img2, flow),
29- # and it's up to whatever consumes the dataset to decide what `valid` should be.
27+ # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid.
28+ # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
29+ # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
3030 _has_builtin_flow_mask = False
3131
3232 def __init__ (self , root , transforms = None ):
@@ -38,11 +38,14 @@ def __init__(self, root, transforms=None):
3838 self ._image_list = []
3939
4040 def _read_img (self , file_name ):
41- return Image .open (file_name )
41+ img = Image .open (file_name )
42+ if img .mode != "RGB" :
43+ img = img .convert ("RGB" )
44+ return img
4245
4346 @abstractmethod
4447 def _read_flow (self , file_name ):
45- # Return the flow or a tuple with the flow and the valid mask if _has_builtin_flow_mask is True
48+ # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
4649 pass
4750
4851 def __getitem__ (self , index ):
@@ -53,23 +56,27 @@ def __getitem__(self, index):
5356 if self ._flow_list : # it will be empty for some dataset when split="test"
5457 flow = self ._read_flow (self ._flow_list [index ])
5558 if self ._has_builtin_flow_mask :
56- flow , valid = flow
59+ flow , valid_flow_mask = flow
5760 else :
58- valid = None
61+ valid_flow_mask = None
5962 else :
60- flow = valid = None
63+ flow = valid_flow_mask = None
6164
6265 if self .transforms is not None :
63- img1 , img2 , flow , valid = self .transforms (img1 , img2 , flow , valid )
66+ img1 , img2 , flow , valid_flow_mask = self .transforms (img1 , img2 , flow , valid_flow_mask )
6467
65- if self ._has_builtin_flow_mask :
66- return img1 , img2 , flow , valid
68+ if self ._has_builtin_flow_mask or valid_flow_mask is not None :
69+ # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
70+ return img1 , img2 , flow , valid_flow_mask
6771 else :
6872 return img1 , img2 , flow
6973
7074 def __len__ (self ):
7175 return len (self ._image_list )
7276
77+ def __rmul__ (self , v ):
78+ return torch .utils .data .ConcatDataset ([self ] * v )
79+
7380
7481class Sintel (FlowDataset ):
7582 """`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
@@ -107,8 +114,8 @@ class Sintel(FlowDataset):
107114 pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
108115 details on the different passes.
109116 transforms (callable, optional): A function/transform that takes in
110- ``img1, img2, flow, valid `` and returns a transformed version.
111- ``valid `` is expected for consistency with other datasets which
117+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
118+ ``valid_flow_mask`` is expected for consistency with other datasets which
112119 return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
113120 """
114121
@@ -140,9 +147,11 @@ def __getitem__(self, index):
140147 index(int): The index of the example to retrieve
141148
142149 Returns:
143- tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow)``.
144- The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a
145- 3-tuple with ``(img1, img2, None)`` is returned.
150+ tuple: A 3-tuple with ``(img1, img2, flow)``.
151+ The flow is a numpy array of shape (2, H, W) and the images are PIL images.
152+ ``flow`` is None if ``split="test"``.
153+ If a valid flow mask is generated within the ``transforms`` parameter,
154+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
146155 """
147156 return super ().__getitem__ (index )
148157
@@ -167,7 +176,7 @@ class KittiFlow(FlowDataset):
167176 root (string): Root directory of the KittiFlow Dataset.
168177 split (string, optional): The dataset split, either "train" (default) or "test"
169178 transforms (callable, optional): A function/transform that takes in
170- ``img1, img2, flow, valid `` and returns a transformed version.
179+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
171180 """
172181
173182 _has_builtin_flow_mask = True
@@ -199,11 +208,11 @@ def __getitem__(self, index):
199208 index(int): The index of the example to retrieve
200209
201210 Returns:
202- tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
203- valid)`` where ``valid `` is a numpy boolean mask of shape (H, W)
211+ tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
212+ where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
204213 indicating which flow values are valid. The flow is a numpy array of
205- shape (2, H, W) and the images are PIL images. If `split="test"`, a
206- 4-tuple with ``(img1, img2, None, None)`` is returned .
214+ shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
215+ ``split="test"``.
207216 """
208217 return super ().__getitem__ (index )
209218
@@ -232,8 +241,8 @@ class FlyingChairs(FlowDataset):
232241 root (string): Root directory of the FlyingChairs Dataset.
233242 split (string, optional): The dataset split, either "train" (default) or "val"
234243 transforms (callable, optional): A function/transform that takes in
235- ``img1, img2, flow, valid `` and returns a transformed version.
236- ``valid `` is expected for consistency with other datasets which
244+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
245+ ``valid_flow_mask`` is expected for consistency with other datasets which
237246 return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
238247 """
239248
@@ -269,6 +278,9 @@ def __getitem__(self, index):
269278 Returns:
270279 tuple: A 3-tuple with ``(img1, img2, flow)``.
271280 The flow is a numpy array of shape (2, H, W) and the images are PIL images.
281+ ``flow`` is None if ``split="val"``.
282+ If a valid flow mask is generated within the ``transforms`` parameter,
283+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
272284 """
273285 return super ().__getitem__ (index )
274286
@@ -300,8 +312,8 @@ class FlyingThings3D(FlowDataset):
300312 details on the different passes.
301313 camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
302314 transforms (callable, optional): A function/transform that takes in
303- ``img1, img2, flow, valid `` and returns a transformed version.
304- ``valid `` is expected for consistency with other datasets which
315+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
316+ ``valid_flow_mask`` is expected for consistency with other datasets which
305317 return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
306318 """
307319
@@ -357,6 +369,9 @@ def __getitem__(self, index):
357369 Returns:
358370 tuple: A 3-tuple with ``(img1, img2, flow)``.
359371 The flow is a numpy array of shape (2, H, W) and the images are PIL images.
372+ ``flow`` is None if ``split="test"``.
373+ If a valid flow mask is generated within the ``transforms`` parameter,
374+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
360375 """
361376 return super ().__getitem__ (index )
362377
@@ -382,7 +397,7 @@ class HD1K(FlowDataset):
382397 root (string): Root directory of the HD1K Dataset.
383398 split (string, optional): The dataset split, either "train" (default) or "test"
384399 transforms (callable, optional): A function/transform that takes in
385- ``img1, img2, flow, valid `` and returns a transformed version.
400+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
386401 """
387402
388403 _has_builtin_flow_mask = True
@@ -422,11 +437,11 @@ def __getitem__(self, index):
422437 index(int): The index of the example to retrieve
423438
424439 Returns:
425- tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
426- valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
440+ tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
441+ is a numpy boolean mask of shape (H, W)
427442 indicating which flow values are valid. The flow is a numpy array of
428- shape (2, H, W) and the images are PIL images. If `split="test"`, a
429- 4-tuple with ``(img1, img2, None, None)`` is returned .
443+ shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
444+ ``split="test"``.
430445 """
431446 return super ().__getitem__ (index )
432447
@@ -451,11 +466,12 @@ def _read_flo(file_name):
451466def _read_16bits_png_with_flow_and_valid_mask (file_name ):
452467
453468 flow_and_valid = _read_png_16 (file_name ).to (torch .float32 )
454- flow , valid = flow_and_valid [:2 , :, :], flow_and_valid [2 , :, :]
469+ flow , valid_flow_mask = flow_and_valid [:2 , :, :], flow_and_valid [2 , :, :]
455470 flow = (flow - 2 ** 15 ) / 64 # This conversion is explained somewhere on the kitti archive
471+ valid_flow_mask = valid_flow_mask .bool ()
456472
457473 # For consistency with other datasets, we convert to numpy
458- return flow .numpy (), valid .numpy ()
474+ return flow .numpy (), valid_flow_mask .numpy ()
459475
460476
461477def _read_pfm (file_name ):
0 commit comments