Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions pl_examples/domain_templates/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)


def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)):
"""Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded."""
path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH)
path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH)
for p_dir in (path_dir_images, path_dir_masks):
os.makedirs(p_dir, exist_ok=True)
for i in range(3):
path_img = os.path.join(path_dir_images, f'dummy_kitti_{i}.png')
Image.new('RGB', image_dims).save(path_img)
path_mask = os.path.join(path_dir_masks, f'dummy_kitti_{i}.png')
Image.new('L', image_dims).save(path_mask)


class KITTI(Dataset):
"""
Class for KITTI Semantic Segmentation Benchmark dataset
Expand All @@ -53,6 +66,12 @@ class KITTI(Dataset):
In the `get_item` function, images and masks are resized to the given `img_size`, masks are
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
(mask does not usually require transforms, but they can be implemented in a similar way).

>>> from pl_examples import DATASETS_PATH
>>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
>>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
>>> KITTI(dataset_path, 'train') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<...semantic_segmentation.KITTI object at ...>
"""
IMAGE_PATH = os.path.join('training', 'image_2')
MASK_PATH = os.path.join('training', 'semantic')
Expand Down Expand Up @@ -141,6 +160,23 @@ class SegModel(pl.LightningModule):
It uses the FCN ResNet50 model as an example.

Adam optimizer is used along with Cosine Annealing learning rate scheduler.

>>> from pl_examples import DATASETS_PATH
>>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
>>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
>>> SegModel(dataset_path) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
SegModel(
(net): UNet(
(layers): ModuleList(
(0): DoubleConv(...)
(1): Down(...)
(2): Down(...)
(3): Up(...)
(4): Up(...)
(5): Conv2d(64, 19, kernel_size=(1, 1), stride=(1, 1))
)
)
)
"""
def __init__(
self,
Expand Down
13 changes: 0 additions & 13 deletions pl_examples/pytorch_ecosystem/__init__.py

This file was deleted.