|
19 | 19 | "Sintel", |
20 | 20 | "FlyingThings3D", |
21 | 21 | "FlyingChairs", |
| 22 | + "HD1K", |
22 | 23 | ) |
23 | 24 |
|
24 | 25 |
|
@@ -363,6 +364,73 @@ def _read_flow(self, file_name): |
363 | 364 | return _read_pfm(file_name) |
364 | 365 |
|
365 | 366 |
|
| 367 | +class HD1K(FlowDataset): |
| 368 | + """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow. |
| 369 | +
|
| 370 | + The dataset is expected to have the following structure: :: |
| 371 | +
|
| 372 | + root |
| 373 | + hd1k |
| 374 | + hd1k_challenge |
| 375 | + image_2 |
| 376 | + hd1k_flow_gt |
| 377 | + flow_occ |
| 378 | + hd1k_input |
| 379 | + image_2 |
| 380 | +
|
| 381 | + Args: |
| 382 | + root (string): Root directory of the HD1K Dataset. |
| 383 | + split (string, optional): The dataset split, either "train" (default) or "test" |
| 384 | + transforms (callable, optional): A function/transform that takes in |
| 385 | + ``img1, img2, flow, valid`` and returns a transformed version. |
| 386 | + """ |
| 387 | + |
| 388 | + _has_builtin_flow_mask = True |
| 389 | + |
| 390 | + def __init__(self, root, split="train", transforms=None): |
| 391 | + super().__init__(root=root, transforms=transforms) |
| 392 | + |
| 393 | + verify_str_arg(split, "split", valid_values=("train", "test")) |
| 394 | + |
| 395 | + root = Path(root) / "hd1k" |
| 396 | + if split == "train": |
| 397 | + # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop |
| 398 | + for seq_idx in range(36): |
| 399 | + flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png"))) |
| 400 | + images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png"))) |
| 401 | + for i in range(len(flows) - 1): |
| 402 | + self._flow_list += [flows[i]] |
| 403 | + self._image_list += [[images[i], images[i + 1]]] |
| 404 | + else: |
| 405 | + images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png"))) |
| 406 | + images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png"))) |
| 407 | + for image1, image2 in zip(images1, images2): |
| 408 | + self._image_list += [[image1, image2]] |
| 409 | + |
| 410 | + if not self._image_list: |
| 411 | + raise FileNotFoundError( |
| 412 | + "Could not find the HD1K images. Please make sure the directory structure is correct." |
| 413 | + ) |
| 414 | + |
| 415 | + def _read_flow(self, file_name): |
| 416 | + return _read_16bits_png_with_flow_and_valid_mask(file_name) |
| 417 | + |
| 418 | + def __getitem__(self, index): |
| 419 | + """Return example at given index. |
| 420 | +
|
| 421 | + Args: |
| 422 | + index(int): The index of the example to retrieve |
| 423 | +
|
| 424 | + Returns: |
| 425 | + tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow, |
| 426 | + valid)`` where ``valid`` is a numpy boolean mask of shape (H, W) |
| 427 | + indicating which flow values are valid. The flow is a numpy array of |
| 428 | + shape (2, H, W) and the images are PIL images. If `split="test"`, a |
| 429 | + 4-tuple with ``(img1, img2, None, None)`` is returned. |
| 430 | + """ |
| 431 | + return super().__getitem__(index) |
| 432 | + |
| 433 | + |
366 | 434 | def _read_flo(file_name): |
367 | 435 | """Read .flo file in Middlebury format""" |
368 | 436 | # Code adapted from: |
|
0 commit comments