Skip to content

Commit 21d313e

Browse files
authored
yapf examples (#5709)
1 parent 07f24d2 commit 21d313e

17 files changed

+243
-212
lines changed

.yapfignore

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
.git/*
22

3-
# TODO
4-
pl_examples/*
5-
63
# TODO
74
pytorch_lightning/*
85

pl_examples/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
_TORCHVISION_AVAILABLE = _module_available("torchvision")
1010
_DALI_AVAILABLE = _module_available("nvidia.dali")
1111

12-
1312
LIGHTNING_LOGO = """
1413
####
1514
###########

pl_examples/basic_examples/autoencoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ def __init__(self):
4343
self.encoder = nn.Sequential(
4444
nn.Linear(28 * 28, 64),
4545
nn.ReLU(),
46-
nn.Linear(64, 3)
46+
nn.Linear(64, 3),
4747
)
4848
self.decoder = nn.Sequential(
4949
nn.Linear(3, 64),
5050
nn.ReLU(),
51-
nn.Linear(64, 28 * 28)
51+
nn.Linear(64, 28 * 28),
5252
)
5353

5454
def forward(self, x):

pl_examples/basic_examples/backbone_image_classifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class Backbone(torch.nn.Module):
3636
(l2): Linear(...)
3737
)
3838
"""
39+
3940
def __init__(self, hidden_dim=128):
4041
super().__init__()
4142
self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
@@ -55,6 +56,7 @@ class LitClassifier(pl.LightningModule):
5556
(backbone): ...
5657
)
5758
"""
59+
5860
def __init__(self, backbone, learning_rate=1e-3):
5961
super().__init__()
6062
self.save_hyperparameters()

pl_examples/basic_examples/conv_sequential_example.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,17 @@
3939
import pl_bolts
4040
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
4141

42-
4342
#####################
4443
# Modules #
4544
#####################
4645

4746

4847
class Flatten(nn.Module):
48+
4949
def forward(self, x):
5050
return x.view(x.size(0), -1)
5151

52+
5253
###############################
5354
# LightningModule #
5455
###############################
@@ -61,6 +62,7 @@ class LitResnet(pl.LightningModule):
6162
(sequential_module): Sequential(...)
6263
)
6364
"""
65+
6466
def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
6567
super().__init__()
6668

@@ -90,9 +92,7 @@ def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
9092
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
9193
nn.ReLU(inplace=False),
9294
nn.MaxPool2d(kernel_size=2, stride=2),
93-
9495
Flatten(),
95-
9696
nn.Dropout(p=0.1),
9797
nn.Linear(4096, 1024),
9898
nn.ReLU(inplace=False),
@@ -159,7 +159,8 @@ def configure_optimizers(self):
159159
optimizer,
160160
0.1,
161161
epochs=self.trainer.max_epochs,
162-
steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)),
162+
steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)
163+
),
163164
'interval': 'step',
164165
}
165166
}
@@ -173,6 +174,7 @@ def automatic_optimization(self) -> bool:
173174
# Instantiate Data Module #
174175
#################################
175176

177+
176178
def instantiate_datamodule(args):
177179
train_transforms = torchvision.transforms.Compose([
178180
torchvision.transforms.RandomCrop(32, padding=4),

pl_examples/basic_examples/dali_image_classifier.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,30 @@ class DALIClassificationLoader(DALIClassificationIterator):
9595
"""
9696

9797
def __init__(
98-
self,
99-
pipelines,
100-
size=-1,
101-
reader_name=None,
102-
auto_reset=False,
103-
fill_last_batch=True,
104-
dynamic_shape=False,
105-
last_batch_padded=False,
98+
self,
99+
pipelines,
100+
size=-1,
101+
reader_name=None,
102+
auto_reset=False,
103+
fill_last_batch=True,
104+
dynamic_shape=False,
105+
last_batch_padded=False,
106106
):
107107
if NEW_DALI_API:
108108
last_batch_policy = LastBatchPolicy.FILL if fill_last_batch else LastBatchPolicy.DROP
109-
super().__init__(pipelines, size, reader_name, auto_reset, dynamic_shape,
110-
last_batch_policy=last_batch_policy, last_batch_padded=last_batch_padded)
109+
super().__init__(
110+
pipelines,
111+
size,
112+
reader_name,
113+
auto_reset,
114+
dynamic_shape,
115+
last_batch_policy=last_batch_policy,
116+
last_batch_padded=last_batch_padded
117+
)
111118
else:
112-
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch,
113-
dynamic_shape, last_batch_padded)
119+
super().__init__(
120+
pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded
121+
)
114122
self._fill_last_batch = fill_last_batch
115123

116124
def __len__(self):
@@ -120,6 +128,7 @@ def __len__(self):
120128

121129

122130
class LitClassifier(pl.LightningModule):
131+
123132
def __init__(self, hidden_dim=128, learning_rate=1e-3):
124133
super().__init__()
125134
self.save_hyperparameters()

pl_examples/basic_examples/mnist_datamodule.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@ def __init__(
5858
super().__init__(*args, **kwargs)
5959
if num_workers and platform.system() == "Windows":
6060
# see: https://stackoverflow.com/a/59680818
61-
warn(f"You have requested num_workers={num_workers} on Windows,"
62-
" but currently recommended is 0, so we set it for you")
61+
warn(
62+
f"You have requested num_workers={num_workers} on Windows,"
63+
" but currently recommended is 0, so we set it for you"
64+
)
6365
num_workers = 0
6466

6567
self.dims = (1, 28, 28)
@@ -132,9 +134,9 @@ def default_transforms(self):
132134
if not _TORCHVISION_AVAILABLE:
133135
return None
134136
if self.normalize:
135-
mnist_transforms = transform_lib.Compose(
136-
[transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
137-
)
137+
mnist_transforms = transform_lib.Compose([
138+
transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
139+
])
138140
else:
139141
mnist_transforms = transform_lib.ToTensor()
140142

pl_examples/basic_examples/simple_image_classifier.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class LitClassifier(pl.LightningModule):
3131
(l2): Linear(...)
3232
)
3333
"""
34+
3435
def __init__(self, hidden_dim=128, learning_rate=1e-3):
3536
super().__init__()
3637
self.save_hyperparameters()

pl_examples/bug_report_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class RandomDataset(Dataset):
3333
>>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS
3434
<...bug_report_model.RandomDataset object at ...>
3535
"""
36+
3637
def __init__(self, size, length):
3738
self.len = length
3839
self.data = torch.randn(length, size)
@@ -124,9 +125,11 @@ def configure_optimizers(self):
124125
# parser = ArgumentParser()
125126
# args = parser.parse_args(opt)
126127

128+
127129
def test_run():
128130

129131
class TestModel(BoringModel):
132+
130133
def on_train_epoch_start(self) -> None:
131134
print('override any method to prove your bug')
132135

pl_examples/domain_templates/computer_vision_fine_tuning.py

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright The PyTorch Lightning team.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -60,14 +59,12 @@
6059

6160
DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
6261

63-
6462
# --- Finetunning Callback ---
6563

64+
6665
class MilestonesFinetuningCallback(BaseFinetuningCallback):
6766

68-
def __init__(self,
69-
milestones: tuple = (5, 10),
70-
train_bn: bool = True):
67+
def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True):
7168
self.milestones = milestones
7269
self.train_bn = train_bn
7370

@@ -78,17 +75,13 @@ def finetunning_function(self, pl_module: pl.LightningModule, epoch: int, optimi
7875
if epoch == self.milestones[0]:
7976
# unfreeze 5 last layers
8077
self.unfreeze_and_add_param_group(
81-
module=pl_module.feature_extractor[-5:],
82-
optimizer=optimizer,
83-
train_bn=self.train_bn
78+
module=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
8479
)
8580

8681
elif epoch == self.milestones[1]:
8782
# unfreeze remaing layers
8883
self.unfreeze_and_add_param_group(
89-
module=pl_module.feature_extractor[:-5],
90-
optimizer=optimizer,
91-
train_bn=self.train_bn
84+
module=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
9285
)
9386

9487

@@ -149,10 +142,12 @@ def __build_model(self):
149142
self.feature_extractor = nn.Sequential(*_layers)
150143

151144
# 2. Classifier:
152-
_fc_layers = [nn.Linear(2048, 256),
153-
nn.ReLU(),
154-
nn.Linear(256, 32),
155-
nn.Linear(32, 1)]
145+
_fc_layers = [
146+
nn.Linear(2048, 256),
147+
nn.ReLU(),
148+
nn.Linear(256, 32),
149+
nn.Linear(32, 1),
150+
]
156151
self.fc = nn.Sequential(*_fc_layers)
157152

158153
# 3. Loss:
@@ -218,25 +213,21 @@ def setup(self, stage: str):
218213

219214
train_dataset = ImageFolder(
220215
root=data_path.joinpath("train"),
221-
transform=transforms.Compose(
222-
[
223-
transforms.Resize((224, 224)),
224-
transforms.RandomHorizontalFlip(),
225-
transforms.ToTensor(),
226-
normalize,
227-
]
228-
),
216+
transform=transforms.Compose([
217+
transforms.Resize((224, 224)),
218+
transforms.RandomHorizontalFlip(),
219+
transforms.ToTensor(),
220+
normalize,
221+
]),
229222
)
230223

231224
valid_dataset = ImageFolder(
232225
root=data_path.joinpath("validation"),
233-
transform=transforms.Compose(
234-
[
235-
transforms.Resize((224, 224)),
236-
transforms.ToTensor(),
237-
normalize,
238-
]
239-
),
226+
transform=transforms.Compose([
227+
transforms.Resize((224, 224)),
228+
transforms.ToTensor(),
229+
normalize,
230+
]),
240231
)
241232

242233
self.train_dataset = train_dataset

0 commit comments

Comments
 (0)