@@ -1874,11 +1874,9 @@ def _inject_pairs(self, root, num_pairs, same):
class SintelTestCase(datasets_utils.ImageDatasetTestCase):
    """Fake-data test case for the Sintel optical-flow dataset."""

    DATASET_CLASS = datasets.Sintel
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final"))
    # A sample is (img1, img2, flow): the flow is a decoded numpy array for the
    # train split and None for the test split (no ground truth there).
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    # Dimensions of the fake .flo files written by inject_fake_data.
    FLOW_H, FLOW_W = 3, 4
18821880
18831881 def inject_fake_data (self , tmpdir , config ):
18841882 root = pathlib .Path (tmpdir ) / "Sintel"
@@ -1899,14 +1897,13 @@ def inject_fake_data(self, tmpdir, config):
18991897 num_examples = num_images_per_scene ,
19001898 )
19011899
1902- # For the ground truth flow value we just create empty files so that they're properly discovered,
1903- # see comment above about EXTRA_PATCHES
19041900 flow_root = root / "training" / "flow"
19051901 for scene_id in range (num_scenes ):
19061902 scene_dir = flow_root / f"scene_{ scene_id } "
19071903 os .makedirs (scene_dir )
19081904 for i in range (num_images_per_scene - 1 ):
1909- open (str (scene_dir / f"frame_000{ i } .flo" ), "a" ).close ()
1905+ file_name = str (scene_dir / f"frame_000{ i } .flo" )
1906+ datasets_utils .make_fake_flo_file (h = self .FLOW_H , w = self .FLOW_W , file_name = file_name )
19101907
19111908 # with e.g. num_images_per_scene = 3, for a single scene with have 3 images
19121909 # which are frame_0000, frame_0001 and frame_0002
@@ -1920,7 +1917,8 @@ def test_flow(self):
19201917 with self .create_dataset (split = "train" ) as (dataset , _ ):
19211918 assert dataset ._flow_list and len (dataset ._flow_list ) == len (dataset ._image_list )
19221919 for _ , _ , flow in dataset :
1923- assert flow == self ._FAKE_FLOW
1920+ assert flow .shape == (2 , self .FLOW_H , self .FLOW_W )
1921+ np .testing .assert_allclose (flow , np .arange (flow .size ).reshape (flow .shape ))
19241922
19251923 # Make sure flow is always None for test split
19261924 with self .create_dataset (split = "test" ) as (dataset , _ ):
@@ -2001,11 +1999,9 @@ def test_bad_input(self):
class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
    """Fake-data test case for the FlyingChairs optical-flow dataset."""

    DATASET_CLASS = datasets.FlyingChairs
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
    # A sample is (img1, img2, flow); the flow decodes to a numpy array.
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    # Dimensions of the fake .flo files written by inject_fake_data.
    FLOW_H, FLOW_W = 3, 4
20092005
20102006 def _make_split_file (self , root , num_examples ):
20112007 # We create a fake split file here, but users are asked to download the real one from the authors website
@@ -2033,10 +2029,9 @@ def inject_fake_data(self, tmpdir, config):
20332029 file_name_fn = lambda image_idx : f"00{ image_idx } _img2.ppm" ,
20342030 num_examples = num_examples_total ,
20352031 )
2036- # For the ground truth flow value we just create empty files so that they're properly discovered,
2037- # see comment above about EXTRA_PATCHES
20382032 for i in range (num_examples_total ):
2039- open (str (root / "data" / f"00{ i } _flow.flo" ), "a" ).close ()
2033+ file_name = str (root / "data" / f"00{ i } _flow.flo" )
2034+ datasets_utils .make_fake_flo_file (h = self .FLOW_H , w = self .FLOW_W , file_name = file_name )
20402035
20412036 self ._make_split_file (root , num_examples )
20422037
@@ -2045,10 +2040,12 @@ def inject_fake_data(self, tmpdir, config):
20452040 @datasets_utils .test_all_configs
20462041 def test_flow (self , config ):
20472042 # Make sure flow always exists, and make sure there are as many flow values as (pairs of) images
2043+ # Also make sure the flow is properly decoded
20482044 with self .create_dataset (config = config ) as (dataset , _ ):
20492045 assert dataset ._flow_list and len (dataset ._flow_list ) == len (dataset ._image_list )
20502046 for _ , _ , flow in dataset :
2051- assert flow == self ._FAKE_FLOW
2047+ assert flow .shape == (2 , self .FLOW_H , self .FLOW_W )
2048+ np .testing .assert_allclose (flow , np .arange (flow .size ).reshape (flow .shape ))
20522049
20532050
20542051if __name__ == "__main__" :
0 commit comments