@@ -1874,11 +1874,9 @@ def _inject_pairs(self, root, num_pairs, same):
 class SintelTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Sintel
     ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final"))
-    # We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files,
-    # which is something we want to avoid.
-    _FAKE_FLOW = "Fake Flow"
-    EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)}
-    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None)))
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
+
+    FLOW_H, FLOW_W = 3, 4
 
     def inject_fake_data(self, tmpdir, config):
         root = pathlib.Path(tmpdir) / "Sintel"
@@ -1899,14 +1897,13 @@ def inject_fake_data(self, tmpdir, config):
                 num_examples=num_images_per_scene,
             )
 
-        # For the ground truth flow value we just create empty files so that they're properly discovered,
-        # see comment above about EXTRA_PATCHES
         flow_root = root / "training" / "flow"
         for scene_id in range(num_scenes):
             scene_dir = flow_root / f"scene_{scene_id}"
             os.makedirs(scene_dir)
             for i in range(num_images_per_scene - 1):
-                open(str(scene_dir / f"frame_000{i}.flo"), "a").close()
+                file_name = str(scene_dir / f"frame_000{i}.flo")
+                datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
 
         # with e.g. num_images_per_scene = 3, for a single scene we have 3 images
         # which are frame_0000, frame_0001 and frame_0002
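
Instead of patching the flow reader, the test now writes .flo files that the real decoder can parse. A minimal sketch of what a helper like datasets_utils.make_fake_flo_file could look like, assuming the Middlebury .flo layout (a float32 magic value 202021.25, then int32 width and height, then the per-pixel (u, v) float32 values) and values chosen so that a decoder returning a (2, h, w) array sees np.arange(2 * h * w), which would match the assert_allclose check in test_flow below:

    import numpy as np

    def make_fake_flo_file(h, w, file_name):
        """Hypothetical sketch: write a valid Middlebury .flo file whose
        decoded (2, h, w) flow equals np.arange(2 * h * w)."""
        values = np.arange(2 * h * w, dtype=np.float32).reshape(2, h, w)
        with open(file_name, "wb") as f:
            f.write(np.float32(202021.25).tobytes())  # .flo magic number
            f.write(np.int32(w).tobytes())
            f.write(np.int32(h).tobytes())
            # the file stores flow row-major as h x w x 2 (u, v per pixel)
            f.write(values.transpose(1, 2, 0).tobytes())
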
@@ -1920,7 +1917,8 @@ def test_flow(self):
         with self.create_dataset(split="train") as (dataset, _):
             assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
             for _, _, flow in dataset:
-                assert flow == self._FAKE_FLOW
+                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
+                np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
 
         # Make sure flow is always None for test split
         with self.create_dataset(split="test") as (dataset, _):
@@ -1929,11 +1927,11 @@ def test_flow(self):
                 assert flow is None
 
     def test_bad_input(self):
-        with pytest.raises(ValueError, match="split must be either"):
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
             with self.create_dataset(split="bad"):
                 pass
 
-        with pytest.raises(ValueError, match="pass_name must be either"):
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
             with self.create_dataset(pass_name="bad"):
                 pass
 
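
The updated match strings suggest that split and pass_name are now validated with torchvision.datasets.utils.verify_str_arg rather than ad-hoc checks; that helper raises a ValueError whose message begins with the pattern matched here. An illustrative call (assuming that validator is what produces the message):

    from torchvision.datasets.utils import verify_str_arg

    # Raises ValueError with a message beginning
    # "Unknown value 'bad' for argument split. ..."
    verify_str_arg("bad", arg="split", valid_values=("train", "test"))
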
@@ -1993,10 +1991,62 @@ def test_flow_and_valid(self):
                 assert valid is None
 
     def test_bad_input(self):
-        with pytest.raises(ValueError, match="split must be either"):
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
             with self.create_dataset(split="bad"):
                 pass
 
 
+class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.FlyingChairs
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
+
+    FLOW_H, FLOW_W = 3, 4
+
+    def _make_split_file(self, root, num_examples):
+        # We create a fake split file here, but users are asked to download the real one from the authors' website
+        split_ids = [1] * num_examples["train"] + [2] * num_examples["val"]
+        random.shuffle(split_ids)
+        with open(str(root / "FlyingChairs_train_val.txt"), "w+") as split_file:
+            for split_id in split_ids:
+                split_file.write(f"{split_id}\n")
+
+    def inject_fake_data(self, tmpdir, config):
+        root = pathlib.Path(tmpdir) / "FlyingChairs"
+
+        num_examples = {"train": 5, "val": 3}
+        num_examples_total = sum(num_examples.values())
+
+        datasets_utils.create_image_folder(  # img1
+            root,
+            name="data",
+            file_name_fn=lambda image_idx: f"00{image_idx}_img1.ppm",
+            num_examples=num_examples_total,
+        )
+        datasets_utils.create_image_folder(  # img2
+            root,
+            name="data",
+            file_name_fn=lambda image_idx: f"00{image_idx}_img2.ppm",
+            num_examples=num_examples_total,
+        )
+        for i in range(num_examples_total):
+            file_name = str(root / "data" / f"00{i}_flow.flo")
+            datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
+
+        self._make_split_file(root, num_examples)
+
+        return num_examples[config["split"]]
+
+    @datasets_utils.test_all_configs
+    def test_flow(self, config):
+        # Make sure flow always exists, and make sure there are as many flow values as (pairs of) images.
+        # Also make sure the flow is properly decoded.
+        with self.create_dataset(config=config) as (dataset, _):
+            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
+            for _, _, flow in dataset:
+                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
+                np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
+
+
 if __name__ == "__main__":
     unittest.main()
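
For context on the split file written by _make_split_file in the new FlyingChairs test: each line of FlyingChairs_train_val.txt holds one integer per example pair, 1 for train and 2 for val. A hedged sketch of how a consumer could turn that file into per-split example indices (illustrative only; the dataset's actual parsing may differ):

    import numpy as np

    split = "train"  # or "val"
    # One integer per example pair: 1 selects train, 2 selects val,
    # matching what _make_split_file writes above.
    split_ids = np.loadtxt("FlyingChairs_train_val.txt", dtype=np.int64)
    indices = np.nonzero(split_ids == (1 if split == "train" else 2))[0]
    print(indices)  # indices of the example pairs belonging to this split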