Skip to content

Commit 53283c2

Browse files
committed
Create fake flo file for more robust testing
1 parent de49c75 commit 53283c2

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

test/datasets_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import random
99
import shutil
1010
import string
11+
import struct
1112
import tarfile
1213
import unittest
1314
import unittest.mock
@@ -922,3 +923,11 @@ def create_random_string(length: int, *digits: str) -> str:
922923
digits = "".join(itertools.chain(*digits))
923924

924925
return "".join(random.choice(digits) for _ in range(length))
926+
927+
928+
def make_fake_flo_file(h, w, file_name):
    """Creates a fake flow file in .flo format.

    Layout: the 4-byte magic tag ``b"PIEH"``, then the width and height as
    32-bit integers, then ``2 * h * w`` 32-bit floats (one interleaved
    (u, v) flow vector per pixel). The values are simply ``0, 1, 2, ...``
    so a test reading the file back can assert the exact decoded content.

    The .flo spec mandates little-endian storage, so explicit ``"<"``
    struct formats are used instead of native byte order/alignment —
    ``"i"``/``"f"`` alone would produce a wrong file on big-endian hosts.

    Args:
        h (int): height of the flow field, in pixels.
        w (int): width of the flow field, in pixels.
        file_name (str): path the fake file is written to.
    """
    values = list(range(2 * h * w))
    content = (
        b"PIEH"
        + struct.pack("<i", w)
        + struct.pack("<i", h)
        + struct.pack("<" + "f" * len(values), *values)
    )
    with open(file_name, "wb") as f:
        f.write(content)

test/test_datasets.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,11 +1874,9 @@ def _inject_pairs(self, root, num_pairs, same):
18741874
class SintelTestCase(datasets_utils.ImageDatasetTestCase):
18751875
DATASET_CLASS = datasets.Sintel
18761876
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final"))
1877-
# We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files,
1878-
# which is something we want to # avoid.
1879-
_FAKE_FLOW = "Fake Flow"
1880-
EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)}
1881-
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None)))
1877+
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
1878+
1879+
FLOW_H, FLOW_W = 3, 4
18821880

18831881
def inject_fake_data(self, tmpdir, config):
18841882
root = pathlib.Path(tmpdir) / "Sintel"
@@ -1899,14 +1897,13 @@ def inject_fake_data(self, tmpdir, config):
18991897
num_examples=num_images_per_scene,
19001898
)
19011899

1902-
# For the ground truth flow value we just create empty files so that they're properly discovered,
1903-
# see comment above about EXTRA_PATCHES
19041900
flow_root = root / "training" / "flow"
19051901
for scene_id in range(num_scenes):
19061902
scene_dir = flow_root / f"scene_{scene_id}"
19071903
os.makedirs(scene_dir)
19081904
for i in range(num_images_per_scene - 1):
1909-
open(str(scene_dir / f"frame_000{i}.flo"), "a").close()
1905+
file_name = str(scene_dir / f"frame_000{i}.flo")
1906+
datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
19101907

19111908
# with e.g. num_images_per_scene = 3, for a single scene with have 3 images
19121909
# which are frame_0000, frame_0001 and frame_0002
@@ -1920,7 +1917,8 @@ def test_flow(self):
19201917
with self.create_dataset(split="train") as (dataset, _):
19211918
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
19221919
for _, _, flow in dataset:
1923-
assert flow == self._FAKE_FLOW
1920+
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
1921+
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
19241922

19251923
# Make sure flow is always None for test split
19261924
with self.create_dataset(split="test") as (dataset, _):
@@ -2001,11 +1999,9 @@ def test_bad_input(self):
20011999
class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
20022000
DATASET_CLASS = datasets.FlyingChairs
20032001
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
2004-
# We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files,
2005-
# which is something we want to avoid.
2006-
_FAKE_FLOW = "Fake Flow"
2007-
EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.FlyingChairs._read_flow", return_value=_FAKE_FLOW)}
2008-
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None)))
2002+
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
2003+
2004+
FLOW_H, FLOW_W = 3, 4
20092005

20102006
def _make_split_file(self, root, num_examples):
20112007
# We create a fake split file here, but users are asked to download the real one from the authors website
@@ -2033,10 +2029,9 @@ def inject_fake_data(self, tmpdir, config):
20332029
file_name_fn=lambda image_idx: f"00{image_idx}_img2.ppm",
20342030
num_examples=num_examples_total,
20352031
)
2036-
# For the ground truth flow value we just create empty files so that they're properly discovered,
2037-
# see comment above about EXTRA_PATCHES
20382032
for i in range(num_examples_total):
2039-
open(str(root / "data" / f"00{i}_flow.flo"), "a").close()
2033+
file_name = str(root / "data" / f"00{i}_flow.flo")
2034+
datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
20402035

20412036
self._make_split_file(root, num_examples)
20422037

@@ -2045,10 +2040,12 @@ def inject_fake_data(self, tmpdir, config):
20452040
@datasets_utils.test_all_configs
20462041
def test_flow(self, config):
20472042
# Make sure flow always exists, and make sure there are as many flow values as (pairs of) images
2043+
# Also make sure the flow is properly decoded
20482044
with self.create_dataset(config=config) as (dataset, _):
20492045
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
20502046
for _, _, flow in dataset:
2051-
assert flow == self._FAKE_FLOW
2047+
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
2048+
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
20522049

20532050

20542051
if __name__ == "__main__":

0 commit comments

Comments
 (0)