1+ import itertools
12import os
3+ import re
24from abc import ABC , abstractmethod
35from glob import glob
46from pathlib import Path
1517__all__ = (
1618 "KittiFlow" ,
1719 "Sintel" ,
20+ "FlyingThings3D" ,
1821 "FlyingChairs" ,
1922)
2023
@@ -271,6 +274,94 @@ def _read_flow(self, file_name):
271274 return _read_flo (file_name )
272275
273276
277+ class FlyingThings3D (FlowDataset ):
278+ """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.
279+
280+ The dataset is expected to have the following structure: ::
281+
282+ root
283+ FlyingThings3D
284+ frames_cleanpass
285+ TEST
286+ TRAIN
287+ frames_finalpass
288+ TEST
289+ TRAIN
290+ optical_flow
291+ TEST
292+ TRAIN
293+
294+ Args:
295+ root (string): Root directory of the intel FlyingThings3D Dataset.
296+ split (string, optional): The dataset split, either "train" (default) or "test"
297+ pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
298+ details on the different passes.
299+ camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
300+ transforms (callable, optional): A function/transform that takes in
301+ ``img1, img2, flow, valid`` and returns a transformed version.
302+ ``valid`` is expected for consistency with other datasets which
303+ return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
304+ """
305+
306+ def __init__ (self , root , split = "train" , pass_name = "clean" , camera = "left" , transforms = None ):
307+ super ().__init__ (root = root , transforms = transforms )
308+
309+ verify_str_arg (split , "split" , valid_values = ("train" , "test" ))
310+ split = split .upper ()
311+
312+ verify_str_arg (pass_name , "pass_name" , valid_values = ("clean" , "final" , "both" ))
313+ passes = {
314+ "clean" : ["frames_cleanpass" ],
315+ "final" : ["frames_finalpass" ],
316+ "both" : ["frames_cleanpass" , "frames_finalpass" ],
317+ }[pass_name ]
318+
319+ verify_str_arg (camera , "camera" , valid_values = ("left" , "right" , "both" ))
320+ cameras = ["left" , "right" ] if camera == "both" else [camera ]
321+
322+ root = Path (root ) / "FlyingThings3D"
323+
324+ directions = ("into_future" , "into_past" )
325+ for pass_name , camera , direction in itertools .product (passes , cameras , directions ):
326+ image_dirs = sorted (glob (str (root / pass_name / split / "*/*" )))
327+ image_dirs = sorted ([Path (image_dir ) / camera for image_dir in image_dirs ])
328+
329+ flow_dirs = sorted (glob (str (root / "optical_flow" / split / "*/*" )))
330+ flow_dirs = sorted ([Path (flow_dir ) / direction / camera for flow_dir in flow_dirs ])
331+
332+ if not image_dirs or not flow_dirs :
333+ raise FileNotFoundError (
334+ "Could not find the FlyingThings3D flow images. "
335+ "Please make sure the directory structure is correct."
336+ )
337+
338+ for image_dir , flow_dir in zip (image_dirs , flow_dirs ):
339+ images = sorted (glob (str (image_dir / "*.png" )))
340+ flows = sorted (glob (str (flow_dir / "*.pfm" )))
341+ for i in range (len (flows ) - 1 ):
342+ if direction == "into_future" :
343+ self ._image_list += [[images [i ], images [i + 1 ]]]
344+ self ._flow_list += [flows [i ]]
345+ elif direction == "into_past" :
346+ self ._image_list += [[images [i + 1 ], images [i ]]]
347+ self ._flow_list += [flows [i + 1 ]]
348+
349+ def __getitem__ (self , index ):
350+ """Return example at given index.
351+
352+ Args:
353+ index(int): The index of the example to retrieve
354+
355+ Returns:
356+ tuple: A 3-tuple with ``(img1, img2, flow)``.
357+ The flow is a numpy array of shape (2, H, W) and the images are PIL images.
358+ """
359+ return super ().__getitem__ (index )
360+
361+ def _read_flow (self , file_name ):
362+ return _read_pfm (file_name )
363+
364+
274365def _read_flo (file_name ):
275366 """Read .flo file in Middlebury format"""
276367 # Code adapted from:
@@ -295,3 +386,31 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name):
295386
296387 # For consistency with other datasets, we convert to numpy
297388 return flow .numpy (), valid .numpy ()
389+
390+
391+ def _read_pfm (file_name ):
392+ """Read flow in .pfm format"""
393+
394+ with open (file_name , "rb" ) as f :
395+ header = f .readline ().rstrip ()
396+ if header != b"PF" :
397+ raise ValueError ("Invalid PFM file" )
398+
399+ dim_match = re .match (rb"^(\d+)\s(\d+)\s$" , f .readline ())
400+ if not dim_match :
401+ raise Exception ("Malformed PFM header." )
402+ w , h = (int (dim ) for dim in dim_match .groups ())
403+
404+ scale = float (f .readline ().rstrip ())
405+ if scale < 0 : # little-endian
406+ endian = "<"
407+ scale = - scale
408+ else :
409+ endian = ">" # big-endian
410+
411+ data = np .fromfile (f , dtype = endian + "f" )
412+
413+ data = data .reshape (h , w , 3 ).transpose (2 , 0 , 1 )
414+ data = np .flip (data , axis = 1 ) # flip on h dimension
415+ data = data [:2 , :, :]
416+ return data .astype (np .float32 )
0 commit comments