|
| 1 | +import itertools |
1 | 2 | import os |
2 | 3 | import re |
3 | 4 | from abc import ABC, abstractmethod |
@@ -320,31 +321,30 @@ def __init__(self, root, split="train", pass_name="clean", camera="left", transf |
320 | 321 |
|
321 | 322 | root = Path(root) / "FlyingThings3D" |
322 | 323 |
|
323 | | - for pass_name in passes: |
324 | | - for camera in cameras: |
325 | | - for direction in ["into_future", "into_past"]: |
326 | | - image_dirs = sorted(glob(str(root / pass_name / split / "*/*"))) |
327 | | - image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs]) |
328 | | - |
329 | | - flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*"))) |
330 | | - flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs]) |
331 | | - |
332 | | - if not image_dirs or not flow_dirs: |
333 | | - raise FileNotFoundError( |
334 | | - "Could not find the FlyingThings3D flow images. " |
335 | | - "Please make sure the directory structure is correct." |
336 | | - ) |
337 | | - |
338 | | - for image_dir, flow_dir in zip(image_dirs, flow_dirs): |
339 | | - images = sorted(glob(str(image_dir / "*.png"))) |
340 | | - flows = sorted(glob(str(flow_dir / "*.pfm"))) |
341 | | - for i in range(len(flows) - 1): |
342 | | - if direction == "into_future": |
343 | | - self._image_list += [[images[i], images[i + 1]]] |
344 | | - self._flow_list += [flows[i]] |
345 | | - elif direction == "into_past": |
346 | | - self._image_list += [[images[i + 1], images[i]]] |
347 | | - self._flow_list += [flows[i + 1]] |
| 324 | + directions = ("into_future", "into_past") |
| 325 | + for pass_name, camera, direction in itertools.product(passes, cameras, directions): |
| 326 | + image_dirs = sorted(glob(str(root / pass_name / split / "*/*"))) |
| 327 | + image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs]) |
| 328 | + |
| 329 | + flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*"))) |
| 330 | + flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs]) |
| 331 | + |
| 332 | + if not image_dirs or not flow_dirs: |
| 333 | + raise FileNotFoundError( |
| 334 | + "Could not find the FlyingThings3D flow images. " |
| 335 | + "Please make sure the directory structure is correct." |
| 336 | + ) |
| 337 | + |
| 338 | + for image_dir, flow_dir in zip(image_dirs, flow_dirs): |
| 339 | + images = sorted(glob(str(image_dir / "*.png"))) |
| 340 | + flows = sorted(glob(str(flow_dir / "*.pfm"))) |
| 341 | + for i in range(len(flows) - 1): |
| 342 | + if direction == "into_future": |
| 343 | + self._image_list += [[images[i], images[i + 1]]] |
| 344 | + self._flow_list += [flows[i]] |
| 345 | + elif direction == "into_past": |
| 346 | + self._image_list += [[images[i + 1], images[i]]] |
| 347 | + self._flow_list += [flows[i + 1]] |
348 | 348 |
|
349 | 349 | def __getitem__(self, index): |
350 | 350 | """Return example at given index. |
|
0 commit comments