3232DEFAULT_VALID_LABELS = (7 , 8 , 11 , 12 , 13 , 17 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 31 , 32 , 33 )
3333
3434
35+ def _create_synth_kitti_dataset (path_dir : str , image_dims : tuple = (1024 , 512 )):
36+ """Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded."""
37+ path_dir_images = os .path .join (path_dir , KITTI .IMAGE_PATH )
38+ path_dir_masks = os .path .join (path_dir , KITTI .MASK_PATH )
39+ for p_dir in (path_dir_images , path_dir_masks ):
40+ os .makedirs (p_dir , exist_ok = True )
41+ for i in range (3 ):
42+ path_img = os .path .join (path_dir_images , f'dummy_kitti_{ i } .png' )
43+ Image .new ('RGB' , image_dims ).save (path_img )
44+ path_mask = os .path .join (path_dir_masks , f'dummy_kitti_{ i } .png' )
45+ Image .new ('L' , image_dims ).save (path_mask )
46+
47+
3548class KITTI (Dataset ):
3649 """
3750 Class for KITTI Semantic Segmentation Benchmark dataset
@@ -53,6 +66,12 @@ class KITTI(Dataset):
5366 In the `get_item` function, images and masks are resized to the given `img_size`, masks are
5467 encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
5568 (mask does not usually require transforms, but they can be implemented in a similar way).
69+
70+ >>> from pl_examples import DATASETS_PATH
71+ >>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
72+ >>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
73+ >>> KITTI(dataset_path, 'train') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
74+ <...semantic_segmentation.KITTI object at ...>
5675 """
5776 IMAGE_PATH = os .path .join ('training' , 'image_2' )
5877 MASK_PATH = os .path .join ('training' , 'semantic' )
@@ -141,6 +160,23 @@ class SegModel(pl.LightningModule):
141160 It uses the FCN ResNet50 model as an example.
142161
143162 Adam optimizer is used along with Cosine Annealing learning rate scheduler.
163+
164+ >>> from pl_examples import DATASETS_PATH
165+ >>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
166+ >>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
167+ >>> SegModel(dataset_path) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
168+ SegModel(
169+ (net): UNet(
170+ (layers): ModuleList(
171+ (0): DoubleConv(...)
172+ (1): Down(...)
173+ (2): Down(...)
174+ (3): Up(...)
175+ (4): Up(...)
176+ (5): Conv2d(64, 19, kernel_size=(1, 1), stride=(1, 1))
177+ )
178+ )
179+ )
144180 """
145181 def __init__ (
146182 self ,
0 commit comments