Skip to content

Commit 518d915

Browse files
Bordajustusschock
andcommitted
add doctests for example 1/n (#5079)
* define tests * fix basic * fix gans * unet * test * drop * format * fix * revert Co-authored-by: Justus Schock <[email protected]> Co-authored-by: Justus Schock <[email protected]>
1 parent 3b83666 commit 518d915

12 files changed

+288
-127
lines changed

pl_examples/basic_examples/autoencoder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131

3232

3333
class LitAutoEncoder(pl.LightningModule):
34+
"""
35+
>>> LitAutoEncoder() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
36+
LitAutoEncoder(
37+
(encoder): ...
38+
(decoder): ...
39+
)
40+
"""
3441

3542
def __init__(self):
3643
super().__init__()

pl_examples/basic_examples/backbone_image_classifier.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929

3030

3131
class Backbone(torch.nn.Module):
32+
"""
33+
>>> Backbone() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
34+
Backbone(
35+
(l1): Linear(...)
36+
(l2): Linear(...)
37+
)
38+
"""
3239
def __init__(self, hidden_dim=128):
3340
super().__init__()
3441
self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
@@ -42,6 +49,12 @@ def forward(self, x):
4249

4350

4451
class LitClassifier(pl.LightningModule):
52+
"""
53+
>>> LitClassifier(Backbone()) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
54+
LitClassifier(
55+
(backbone): ...
56+
)
57+
"""
4558
def __init__(self, backbone, learning_rate=1e-3):
4659
super().__init__()
4760
self.save_hyperparameters()

pl_examples/basic_examples/conv_sequential_example.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ def forward(self, x):
5555

5656

5757
class LitResnet(pl.LightningModule):
58+
"""
59+
>>> LitResnet() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
60+
LitResnet(
61+
(sequential_module): Sequential(...)
62+
)
63+
"""
5864
def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
5965
super().__init__()
6066

pl_examples/basic_examples/mnist_datamodule.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
class MNISTDataModule(LightningDataModule):
3030
"""
3131
Standard MNIST, train, val, test splits and transforms
32+
33+
>>> MNISTDataModule() # doctest: +ELLIPSIS
34+
<...mnist_datamodule.MNISTDataModule object at ...>
3235
"""
3336

3437
name = "mnist"

pl_examples/basic_examples/simple_image_classifier.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@
2424

2525

2626
class LitClassifier(pl.LightningModule):
27+
"""
28+
>>> LitClassifier() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
29+
LitClassifier(
30+
(l1): Linear(...)
31+
(l2): Linear(...)
32+
)
33+
"""
2734
def __init__(self, hidden_dim=128, learning_rate=1e-3):
2835
super().__init__()
2936
self.save_hyperparameters()

pl_examples/bug_report_model.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828

2929

3030
class RandomDataset(Dataset):
31+
"""
32+
>>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS
33+
<...bug_report_model.RandomDataset object at ...>
34+
"""
3135
def __init__(self, size, length):
3236
self.len = length
3337
self.data = torch.randn(length, size)
@@ -40,6 +44,12 @@ def __len__(self):
4044

4145

4246
class BoringModel(LightningModule):
47+
"""
48+
>>> BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
49+
BoringModel(
50+
(layer): Linear(...)
51+
)
52+
"""
4353

4454
def __init__(self):
4555
"""
@@ -113,10 +123,9 @@ def configure_optimizers(self):
113123
# parser = ArgumentParser()
114124
# args = parser.parse_args(opt)
115125

116-
def run_test():
126+
def test_run():
117127

118128
class TestModel(BoringModel):
119-
120129
def on_train_epoch_start(self) -> None:
121130
print('override any method to prove your bug')
122131

@@ -140,4 +149,4 @@ def on_train_epoch_start(self) -> None:
140149

141150
if __name__ == '__main__':
142151
cli_lightning_logo()
143-
run_test()
152+
test_run()

pl_examples/domain_templates/computer_vision_fine_tuning.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -159,20 +159,30 @@ def _unfreeze_and_add_param_group(module: Module,
159159
class TransferLearningModel(pl.LightningModule):
160160
"""Transfer Learning with pre-trained ResNet50.
161161
162-
Args:
163-
hparams: Model hyperparameters
164-
dl_path: Path where the data will be downloaded
162+
>>> with TemporaryDirectory(dir='.') as tmp_dir:
163+
... TransferLearningModel(tmp_dir) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
164+
TransferLearningModel(
165+
(feature_extractor): Sequential(...)
166+
(fc): Sequential(...)
167+
)
165168
"""
166-
def __init__(self,
167-
dl_path: Union[str, Path],
168-
backbone: str = 'resnet50',
169-
train_bn: bool = True,
170-
milestones: tuple = (5, 10),
171-
batch_size: int = 8,
172-
lr: float = 1e-2,
173-
lr_scheduler_gamma: float = 1e-1,
174-
num_workers: int = 6, **kwargs) -> None:
175-
super().__init__()
169+
def __init__(
170+
self,
171+
dl_path: Union[str, Path],
172+
backbone: str = 'resnet50',
173+
train_bn: bool = True,
174+
milestones: tuple = (5, 10),
175+
batch_size: int = 8,
176+
lr: float = 1e-2,
177+
lr_scheduler_gamma: float = 1e-1,
178+
num_workers: int = 6,
179+
**kwargs,
180+
) -> None:
181+
"""
182+
Args:
183+
dl_path: Path where the data will be downloaded
184+
"""
185+
super().__init__(**kwargs)
176186
self.dl_path = dl_path
177187
self.backbone = backbone
178188
self.train_bn = train_bn

pl_examples/domain_templates/generative_adversarial_net.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,13 @@
3737

3838

3939
class Generator(nn.Module):
40-
def __init__(self, latent_dim, img_shape):
40+
"""
41+
>>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
42+
Generator(
43+
(model): Sequential(...)
44+
)
45+
"""
46+
def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
4147
super().__init__()
4248
self.img_shape = img_shape
4349

@@ -64,6 +70,12 @@ def forward(self, z):
6470

6571

6672
class Discriminator(nn.Module):
73+
"""
74+
>>> Discriminator(img_shape=(1, 28, 28)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
75+
Discriminator(
76+
(model): Sequential(...)
77+
)
78+
"""
6779
def __init__(self, img_shape):
6880
super().__init__()
6981

@@ -83,6 +95,37 @@ def forward(self, img):
8395

8496

8597
class GAN(LightningModule):
98+
"""
99+
>>> GAN(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
100+
GAN(
101+
(generator): Generator(
102+
(model): Sequential(...)
103+
)
104+
(discriminator): Discriminator(
105+
(model): Sequential(...)
106+
)
107+
)
108+
"""
109+
def __init__(
110+
self,
111+
img_shape: tuple = (1, 28, 28),
112+
lr: float = 0.0002,
113+
b1: float = 0.5,
114+
b2: float = 0.999,
115+
latent_dim: int = 100,
116+
):
117+
super().__init__()
118+
119+
self.save_hyperparameters()
120+
121+
# networks
122+
self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape)
123+
self.discriminator = Discriminator(img_shape=img_shape)
124+
125+
self.validation_z = torch.randn(8, self.hparams.latent_dim)
126+
127+
self.example_input_array = torch.zeros(2, self.hparams.latent_dim)
128+
86129
@staticmethod
87130
def add_argparse_args(parent_parser: ArgumentParser):
88131
parser = ArgumentParser(parents=[parent_parser], add_help=False)
@@ -96,20 +139,6 @@ def add_argparse_args(parent_parser: ArgumentParser):
96139

97140
return parser
98141

99-
def __init__(self, hparams: Namespace):
100-
super().__init__()
101-
102-
self.hparams = hparams
103-
104-
# networks
105-
mnist_shape = (1, 28, 28)
106-
self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=mnist_shape)
107-
self.discriminator = Discriminator(img_shape=mnist_shape)
108-
109-
self.validation_z = torch.randn(8, self.hparams.latent_dim)
110-
111-
self.example_input_array = torch.zeros(2, self.hparams.latent_dim)
112-
113142
def forward(self, z):
114143
return self.generator(z)
115144

@@ -180,6 +209,10 @@ def on_epoch_end(self):
180209

181210

182211
class MNISTDataModule(LightningDataModule):
212+
"""
213+
>>> MNISTDataModule() # doctest: +ELLIPSIS
214+
<...generative_adversarial_net.MNISTDataModule object at ...>
215+
"""
183216
def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4):
184217
super().__init__()
185218
self.batch_size = batch_size

pl_examples/domain_templates/imagenet.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@
5050

5151

5252
class ImageNetLightningModel(LightningModule):
53+
"""
54+
>>> ImageNetLightningModel(data_path='missing') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
55+
ImageNetLightningModel(
56+
(model): ResNet(...)
57+
)
58+
"""
5359
# pull out resnet names from torchvision models
5460
MODEL_NAMES = sorted(
5561
name for name in models.__dict__
@@ -58,14 +64,14 @@ class ImageNetLightningModel(LightningModule):
5864

5965
def __init__(
6066
self,
61-
arch: str,
62-
pretrained: bool,
63-
lr: float,
64-
momentum: float,
65-
weight_decay: int,
6667
data_path: str,
67-
batch_size: int,
68-
workers: int,
68+
arch: str = 'resnet18',
69+
pretrained: bool = False,
70+
lr: float = 0.1,
71+
momentum: float = 0.9,
72+
weight_decay: float = 1e-4,
73+
batch_size: int = 4,
74+
workers: int = 2,
6975
**kwargs,
7076
):
7177
super().__init__()

0 commit comments

Comments
 (0)