Skip to content

Commit dfcb62a

Browse files
Bordatchaton
andcommitted
add doctests for example 2/n segmentation (#5083)
* draft * fix * drop folder Co-authored-by: chaton <[email protected]>
1 parent 86daa38 commit dfcb62a

File tree

2 files changed

+36
-13
lines changed

2 files changed

+36
-13
lines changed

pl_examples/domain_templates/semantic_segmentation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,19 @@
3232
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)
3333

3434

35+
def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)):
36+
"""Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded."""
37+
path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH)
38+
path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH)
39+
for p_dir in (path_dir_images, path_dir_masks):
40+
os.makedirs(p_dir, exist_ok=True)
41+
for i in range(3):
42+
path_img = os.path.join(path_dir_images, f'dummy_kitti_{i}.png')
43+
Image.new('RGB', image_dims).save(path_img)
44+
path_mask = os.path.join(path_dir_masks, f'dummy_kitti_{i}.png')
45+
Image.new('L', image_dims).save(path_mask)
46+
47+
3548
class KITTI(Dataset):
3649
"""
3750
Class for KITTI Semantic Segmentation Benchmark dataset
@@ -53,6 +66,12 @@ class KITTI(Dataset):
5366
In the `get_item` function, images and masks are resized to the given `img_size`, masks are
5467
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
5568
(mask does not usually require transforms, but they can be implemented in a similar way).
69+
70+
>>> from pl_examples import DATASETS_PATH
71+
>>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
72+
>>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
73+
>>> KITTI(dataset_path, 'train') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
74+
<...semantic_segmentation.KITTI object at ...>
5675
"""
5776
IMAGE_PATH = os.path.join('training', 'image_2')
5877
MASK_PATH = os.path.join('training', 'semantic')
@@ -141,6 +160,23 @@ class SegModel(pl.LightningModule):
141160
It uses the FCN ResNet50 model as an example.
142161
143162
Adam optimizer is used along with Cosine Annealing learning rate scheduler.
163+
164+
>>> from pl_examples import DATASETS_PATH
165+
>>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
166+
>>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
167+
>>> SegModel(dataset_path) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
168+
SegModel(
169+
(net): UNet(
170+
(layers): ModuleList(
171+
(0): DoubleConv(...)
172+
(1): Down(...)
173+
(2): Down(...)
174+
(3): Up(...)
175+
(4): Up(...)
176+
(5): Conv2d(64, 19, kernel_size=(1, 1), stride=(1, 1))
177+
)
178+
)
179+
)
144180
"""
145181
def __init__(
146182
self,

pl_examples/pytorch_ecosystem/__init__.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

0 commit comments

Comments
 (0)