Skip to content

Commit cd5237f

Browse files
committed
Use verify_str_arg
1 parent 53283c2 commit cd5237f

File tree

2 files changed

+8
-12
lines changed

2 files changed

+8
-12
lines changed

test/test_datasets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,11 +1927,11 @@ def test_flow(self):
19271927
assert flow is None
19281928

19291929
def test_bad_input(self):
1930-
with pytest.raises(ValueError, match="split must be either"):
1930+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
19311931
with self.create_dataset(split="bad"):
19321932
pass
19331933

1934-
with pytest.raises(ValueError, match="pass_name must be either"):
1934+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
19351935
with self.create_dataset(pass_name="bad"):
19361936
pass
19371937

@@ -1991,7 +1991,7 @@ def test_flow_and_valid(self):
19911991
assert valid is None
19921992

19931993
def test_bad_input(self):
1994-
with pytest.raises(ValueError, match="split must be either"):
1994+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
19951995
with self.create_dataset(split="bad"):
19961996
pass
19971997

torchvision/datasets/_optical_flow.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from ..io.image import _read_png_16
1111
from .vision import VisionDataset
12+
from .utils import verify_str_arg
1213

1314

1415
__all__ = (
@@ -110,11 +111,8 @@ class Sintel(FlowDataset):
110111
def __init__(self, root, split="train", pass_name="clean", transforms=None):
111112
super().__init__(root=root, transforms=transforms)
112113

113-
if split not in ("train", "test"):
114-
raise ValueError("split must be either 'train' or 'test'")
115-
116-
if pass_name not in ("clean", "final"):
117-
raise ValueError("pass_name must be either 'clean' or 'final'")
114+
verify_str_arg(split, "split", valid_values=("train", "test"))
115+
verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final"))
118116

119117
root = Path(root) / "Sintel"
120118

@@ -172,8 +170,7 @@ class KittiFlow(FlowDataset):
172170
def __init__(self, root, split="train", transforms=None):
173171
super().__init__(root=root, transforms=transforms)
174172

175-
if split not in ("train", "test"):
176-
raise ValueError("split must be either 'train' or 'test'")
173+
verify_str_arg(split, "split", valid_values=("train", "test"))
177174

178175
root = Path(root) / "Kitti" / (split + "ing")
179176
images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
@@ -238,8 +235,7 @@ class FlyingChairs(FlowDataset):
238235
def __init__(self, root, split="train", transforms=None):
239236
super().__init__(root=root, transforms=transforms)
240237

241-
if split not in ("train", "val"):
242-
raise ValueError("split must be either 'train' or 'val'")
238+
verify_str_arg(split, "split", valid_values=("train", "val"))
243239

244240
root = Path(root) / "FlyingChairs"
245241
images = sorted(glob(str(root / "data" / "*.ppm")))

0 commit comments

Comments
 (0)