Skip to content

Commit a18df32

Browse files
authored
Merge branch 'main' into ci/android_app
2 parents 7b11326 + 0fa747e commit a18df32

File tree

8 files changed

+125
-11
lines changed

8 files changed

+125
-11
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
4545
Flickr30k
4646
FlyingChairs
4747
FlyingThings3D
48+
HD1K
4849
HMDB51
4950
ImageNet
5051
INaturalist

docs/source/models.rst

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,6 @@ Inception v3
323323

324324
inception_v3
325325

326-
.. note ::
327-
This requires `scipy` to be installed
328-
329-
330326
GoogLeNet
331327
------------
332328

@@ -336,10 +332,6 @@ GoogLeNet
336332

337333
googlenet
338334

339-
.. note ::
340-
This requires `scipy` to be installed
341-
342-
343335
ShuffleNet v2
344336
-------------
345337

packaging/windows/internal/cuda_install.bat

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" (
167167
curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%"
168168
if errorlevel 1 exit /b 1
169169
set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%"
170-
set "ARGS=thrust_11.3 nvcc_11.3 cuobjdump_11.3 nvprune_11.3 nvprof_11.3 cupti_11.3 cublas_11.3 cublas_dev_11.3 cudart_11.3 cufft_11.3 cufft_dev_11.3 curand_11.3 curand_dev_11.3 cusolver_11.3 cusolver_dev_11.3 cusparse_11.3 cusparse_dev_11.3 npp_11.3 npp_dev_11.3 nvrtc_11.3 nvrtc_dev_11.3 nvml_dev_11.3"
170+
set "ARGS=thrust_11.3 nvcc_11.3 cuobjdump_11.3 nvprune_11.3 nvprof_11.3 cupti_11.3 cublas_11.3 cublas_dev_11.3 cudart_11.3 cufft_11.3 cufft_dev_11.3 curand_11.3 curand_dev_11.3 cusolver_11.3 cusolver_dev_11.3 cusparse_11.3 cusparse_dev_11.3 npp_11.3 npp_dev_11.3 nvjpeg_11.3 nvjpeg_dev_11.3 nvrtc_11.3 nvrtc_dev_11.3 nvml_dev_11.3"
171171

172172
)
173173

references/classification/train.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,8 @@ def main(args):
325325
args.start_epoch = checkpoint["epoch"] + 1
326326
if model_ema:
327327
model_ema.load_state_dict(checkpoint["model_ema"])
328+
if scaler:
329+
scaler.load_state_dict(checkpoint["scaler"])
328330

329331
if args.test_only:
330332
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
@@ -356,6 +358,8 @@ def main(args):
356358
}
357359
if model_ema:
358360
checkpoint["model_ema"] = model_ema.state_dict()
361+
if scaler:
362+
checkpoint["scaler"] = scaler.state_dict()
359363
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
360364
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
361365

test/builtin_dataset_mocks.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,16 @@ def _get(self, dataset, config):
100100
return mock_resources, mock_info
101101

102102
def _decoder(self, dataset_type):
    """Pick the decoder callable used when loading a mocked dataset.

    Args:
        dataset_type: Compared against ``datasets.utils.DatasetType.RAW``.

    Returns:
        ``datasets.decoder.raw`` for RAW datasets; otherwise a function that
        reads the whole file object and closes it.
    """
    if dataset_type == datasets.utils.DatasetType.RAW:
        return datasets.decoder.raw

    def to_bytes(file):
        # Drain the handle and release it even when read() raises.
        try:
            payload = file.read()
        finally:
            file.close()
        return payload

    return to_bytes
107113

108114
def load(
109115
self, name: str, decoder=DEFAULT_TEST_DECODER, split="train", **options: Any

test/test_datasets.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2126,5 +2126,47 @@ def test_bad_input(self):
21262126
pass
21272127

21282128

2129+
class HD1KTestCase(KittiFlowTestCase):
    """Test case for ``datasets.HD1K``, reusing what ``KittiFlowTestCase`` provides."""

    DATASET_CLASS = datasets.HD1K

    def inject_fake_data(self, tmpdir, config):
        """Create a fake HD1K directory tree below ``tmpdir``.

        Both the training folders and the challenge (test) folder are
        populated, whatever ``config["split"]`` is.

        Returns:
            int: ``num_sequences * (frames per sequence - 1)`` for the
            requested split.
        """
        is_train = config["split"] == "train"
        hd1k_root = pathlib.Path(tmpdir) / "hd1k"

        num_sequences = 4 if is_train else 3
        frames_per_train_sequence = 3

        # (parent directory, image folder) pairs that make up the training data.
        train_folders = (("hd1k_input", "image_2"), ("hd1k_flow_gt", "flow_occ"))

        for seq_idx in range(num_sequences):
            sequence_tag = f"{seq_idx:06d}"

            # Training data
            for parent, folder in train_folders:
                datasets_utils.create_image_folder(
                    hd1k_root / parent,
                    name=folder,
                    file_name_fn=lambda image_idx: f"{sequence_tag}_{image_idx}.png",
                    num_examples=frames_per_train_sequence,
                )

            # Test data: one "_10" and one "_11" frame per sequence
            for frame in (10, 11):
                datasets_utils.create_image_folder(
                    hd1k_root / "hd1k_challenge",
                    name="image_2",
                    file_name_fn=lambda _: f"{sequence_tag}_{frame}.png",
                    num_examples=1,
                )

        frames_per_sequence = frames_per_train_sequence if is_train else 2
        return num_sequences * (frames_per_sequence - 1)
2169+
2170+
21292171
if __name__ == "__main__":
21302172
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D
1+
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D, HD1K
22
from .caltech import Caltech101, Caltech256
33
from .celeba import CelebA
44
from .cifar import CIFAR10, CIFAR100
@@ -76,4 +76,5 @@
7676
"Sintel",
7777
"FlyingChairs",
7878
"FlyingThings3D",
79+
"HD1K",
7980
)

torchvision/datasets/_optical_flow.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"Sintel",
2020
"FlyingThings3D",
2121
"FlyingChairs",
22+
"HD1K",
2223
)
2324

2425

@@ -363,6 +364,73 @@ def _read_flow(self, file_name):
363364
return _read_pfm(file_name)
364365

365366

367+
class HD1K(FlowDataset):
    """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            hd1k
                hd1k_challenge
                    image_2
                hd1k_flow_gt
                    flow_occ
                hd1k_input
                    image_2

    Args:
        root (string): Root directory of the HD1K Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
    """

    _has_builtin_flow_mask = True

    def __init__(self, root, split="train", transforms=None):
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        base_dir = Path(root) / "hd1k"
        if split == "train":
            flow_dir = base_dir / "hd1k_flow_gt" / "flow_occ"
            input_dir = base_dir / "hd1k_input" / "image_2"
            # There are 36 "sequences". Each one is globbed separately so that a
            # pair never straddles the end of seq i and the start of seq i + 1.
            for seq_idx in range(36):
                pattern = f"{seq_idx:06d}_*.png"
                flows = sorted(glob(str(flow_dir / pattern)))
                images = sorted(glob(str(input_dir / pattern)))
                # Every flow file except the last of the sequence yields one example:
                # flow i together with frames i and i + 1.
                self._flow_list += flows[:-1]
                self._image_list += [[images[i], images[i + 1]] for i in range(len(flows) - 1)]
        else:
            challenge_dir = base_dir / "hd1k_challenge" / "image_2"
            first_frames = sorted(glob(str(challenge_dir / "*10.png")))
            second_frames = sorted(glob(str(challenge_dir / "*11.png")))
            self._image_list += [[first, second] for first, second in zip(first_frames, second_frames)]

        if not self._image_list:
            raise FileNotFoundError(
                "Could not find the HD1K images. Please make sure the directory structure is correct."
            )

    def _read_flow(self, file_name):
        # Delegates to the module's 16-bit PNG reader.
        return _read_16bits_png_with_flow_and_valid_mask(file_name)

    def __getitem__(self, index):
        """Return example at given index.

        Args:
            index (int): The index of the example to retrieve

        Returns:
            tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
            valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid. The flow is a numpy array of
            shape (2, H, W) and the images are PIL images. If ``split="test"``, a
            4-tuple with ``(img1, img2, None, None)`` is returned.
        """
        return super().__getitem__(index)
432+
433+
366434
def _read_flo(file_name):
367435
"""Read .flo file in Middlebury format"""
368436
# Code adapted from:

0 commit comments

Comments
 (0)