|
9 | 9 |
|
10 | 10 | from ..io.image import _read_png_16 |
11 | 11 | from .vision import VisionDataset |
| 12 | +from .utils import verify_str_arg |
12 | 13 |
|
13 | 14 |
|
14 | 15 | __all__ = ( |
@@ -110,11 +111,8 @@ class Sintel(FlowDataset): |
110 | 111 | def __init__(self, root, split="train", pass_name="clean", transforms=None): |
111 | 112 | super().__init__(root=root, transforms=transforms) |
112 | 113 |
|
113 | | - if split not in ("train", "test"): |
114 | | - raise ValueError("split must be either 'train' or 'test'") |
115 | | - |
116 | | - if pass_name not in ("clean", "final"): |
117 | | - raise ValueError("pass_name must be either 'clean' or 'final'") |
| 114 | + verify_str_arg(split, "split", valid_values=("train", "test")) |
| 115 | + verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final")) |
118 | 116 |
|
119 | 117 | root = Path(root) / "Sintel" |
120 | 118 |
|
@@ -172,8 +170,7 @@ class KittiFlow(FlowDataset): |
172 | 170 | def __init__(self, root, split="train", transforms=None): |
173 | 171 | super().__init__(root=root, transforms=transforms) |
174 | 172 |
|
175 | | - if split not in ("train", "test"): |
176 | | - raise ValueError("split must be either 'train' or 'test'") |
| 173 | + verify_str_arg(split, "split", valid_values=("train", "test")) |
177 | 174 |
|
178 | 175 | root = Path(root) / "Kitti" / (split + "ing") |
179 | 176 | images1 = sorted(glob(str(root / "image_2" / "*_10.png"))) |
@@ -238,8 +235,7 @@ class FlyingChairs(FlowDataset): |
238 | 235 | def __init__(self, root, split="train", transforms=None): |
239 | 236 | super().__init__(root=root, transforms=transforms) |
240 | 237 |
|
241 | | - if split not in ("train", "val"): |
242 | | - raise ValueError("split must be either 'train' or 'val'") |
| 238 | + verify_str_arg(split, "split", valid_values=("train", "val")) |
243 | 239 |
|
244 | 240 | root = Path(root) / "FlyingChairs" |
245 | 241 | images = sorted(glob(str(root / "data" / "*.ppm"))) |
|
0 commit comments