
Commit 0c728f3

pep8
Co-authored-by: @awaelchi
1 parent a503631 commit 0c728f3

4 files changed: 44 additions, 68 deletions

pytorch_lightning/accelerators/accelerator.py

Lines changed: 18 additions & 17 deletions
@@ -229,7 +229,14 @@ def backward(
 
         return output
 
-    def optimizer_step(self, optimizer: torch.optim.Optimizer, current_epoch: int, batch_idx: int, opt_idx: int, lambda_closure: Callable):
+    def optimizer_step(
+        self,
+        optimizer: torch.optim.Optimizer,
+        current_epoch: int,
+        batch_idx: int,
+        opt_idx: int,
+        lambda_closure: Callable,
+    ):
         """performs the actual optimizer step.
 
         Args:
@@ -265,16 +272,15 @@ def optimizer_step(self, optimizer: torch.optim.Optimizer, current_epoch: int, b
         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx)
         return res
 
-    def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
-        """Zeros all model parameter's gradients
-        """
+    def optimizer_zero_grad(
+        self, current_epoch: int, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
+    ) -> None:
+        """Zeros all model parameter's gradients"""
         model_ref = self.lightning_module
         model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
 
     def clip_gradients(self, optimizer: torch.optim.Optimizer, clip_val: Union[int, float]) -> None:
-        """clips all the optimizer parameters to the given value
-
-        """
+        """clips all the optimizer parameters to the given value"""
 
         self.precision_plugin.clip_gradients(optimizer, clip_val)
 
@@ -287,11 +293,10 @@ def on_train_epoch_end(self, outputs) -> None:
         pass
 
     def on_train_end(self) -> None:
-        """Hook to do something at the end of the training
-        """
+        """Hook to do something at the end of the training"""
         pass
 
-    def setup_optimizers(self, trainer: 'Trainer', model: LightningModule):
+    def setup_optimizers(self, trainer: "Trainer", model: LightningModule):
         """creates optimizers and schedulers
 
         Args:
@@ -306,25 +311,21 @@ def setup_optimizers(self, trainer: 'Trainer', model: LightningModule):
         self.optimizer_frequencies = optimizer_frequencies
 
     def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
-        """Attaches the training type plugin to the accelerator.
+        """Attaches the training type plugin to the accelerator.
         Also transfers ownership of the model to this plugin
 
         """
         plugin.connect(model)
 
     def connect_precision_plugin(self, plugin: PrecisionPlugin):
-        """Attaches the precision plugin to the accelerator
-
-        """
+        """Attaches the precision plugin to the accelerator"""
         model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
         self.model = model
         self.optimizers = optimizers
         self.schedulers = schedulers
 
     def to_device(self, batch: Any) -> Any:
-        """Pushes the batch to the root device
-
-        """
+        """Pushes the batch to the root device"""
         return self.batch_to_device(batch, self.root_device)
 
     @property
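
For orientation, a minimal sketch (not part of this commit) of how the reformatted optimizer_step and optimizer_zero_grad signatures might be called from a training loop; accelerator, optimizer, epoch and batch_idx are assumed to already exist, and training_closure is a hypothetical stand-in for the closure the training loop normally builds:

    def training_closure():
        # hypothetical closure: recompute the loss and run backward();
        # in Lightning the training loop builds this, not the user
        ...

    accelerator.optimizer_step(
        optimizer=optimizer,
        current_epoch=epoch,
        batch_idx=batch_idx,
        opt_idx=0,
        lambda_closure=training_closure,
    )
    accelerator.optimizer_zero_grad(epoch, batch_idx, optimizer, opt_idx=0)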

pytorch_lightning/plugins/base_plugin.py

Lines changed: 8 additions & 16 deletions
@@ -4,8 +4,7 @@
 
 
 class Plugin(object):
-    """Basic Plugin class to derive precision and training type plugins from.
-    """
+    """Basic Plugin class to derive precision and training type plugins from."""
 
     def connect(self, model: torch.nn.Module, *args, **kwargs):
         """Connects the plugin with the accelerator (and thereby with trainer and model).
@@ -14,39 +13,32 @@ def connect(self, model: torch.nn.Module, *args, **kwargs):
         pass
 
     def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int):
-        """Hook to do something before each optimizer step.
-        """
+        """Hook to do something before each optimizer step."""
         pass
 
     def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int):
-        """Hook to do something after each optimizer step.
-        """
+        """Hook to do something after each optimizer step."""
         pass
 
     def pre_training(self):
-        """Hook to do something before the training starts.
-        """
+        """Hook to do something before the training starts."""
         pass
 
     def post_training(self):
-        """Hook to do something after the training finishes.
-        """
+        """Hook to do something after the training finishes."""
         pass
 
     @contextlib.contextmanager
     def train_step_context(self):
-        """A contextmanager for the trainstep
-        """
+        """A contextmanager for the trainstep"""
         yield
 
     @contextlib.contextmanager
     def val_step_context(self):
-        """A contextmanager for the validation step
-        """
+        """A contextmanager for the validation step"""
         yield
 
     @contextlib.contextmanager
     def test_step_context(self):
-        """A contextmanager for the teststep
-        """
+        """A contextmanager for the teststep"""
         yield
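
All of the hooks above are no-ops by default, so a subclass only overrides what it needs. A minimal sketch against the Plugin base shown above; the TimingPlugin name and the timing logic are illustrative, not part of this commit:

    import contextlib
    import time

    from pytorch_lightning.plugins.base_plugin import Plugin


    class TimingPlugin(Plugin):
        """Illustrative plugin that times every optimizer step."""

        def pre_optimizer_step(self, optimizer, optimizer_idx):
            self._step_start = time.monotonic()

        def post_optimizer_step(self, optimizer, optimizer_idx):
            elapsed = time.monotonic() - self._step_start
            print(f"optimizer {optimizer_idx}: step took {elapsed:.4f}s")

        @contextlib.contextmanager
        def train_step_context(self):
            # wraps the training step; profiling or autocast could go here
            yield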

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 4 additions & 8 deletions
@@ -4,7 +4,7 @@
 import torch
 from torch.optim import Optimizer
 
-from pytorch_lightning.plugins .base_plugin import Plugin
+from pytorch_lightning.plugins.base_plugin import Plugin
 from pytorch_lightning.core import LightningModule
 
 
@@ -13,7 +13,7 @@ class PrecisionPlugin(Plugin):
     precision = 32
 
     def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Tensor, None, None]:
-        """The master params of the model. Returns the plain model params here.
+        """The master params of the model. Returns the plain model params here.
         Maybe different in other precision plugins.
 
         """
@@ -22,9 +22,7 @@ def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Ten
             yield p
 
     def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
-        """Connects this plugin to the accelerator and the training process
-
-        """
+        """Connects this plugin to the accelerator and the training process"""
         return model, optimizers, lr_schedulers
 
     def backward(
@@ -61,9 +59,7 @@ def backward(
         return closure_loss
 
     def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
-        """Clips the gradients to a specific value
-
-        """
+        """Clips the gradients to a specific value"""
         # TODO: separate TPU case from here
         if clip_val is None:
             return
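
As a usage sketch for the methods touched here (assuming an existing PrecisionPlugin instance named plugin and a torch.optim.Optimizer named optimizer, both hypothetical for this example): master_params yields the plain model parameters, and clip_gradients returns early when clip_val is None, as the body above shows.

    # plain model parameters for the default 32-bit precision
    params = list(plugin.master_params(optimizer))

    plugin.clip_gradients(optimizer, clip_val=0.5)  # clip_val=None would be a no-op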

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 14 additions & 27 deletions
@@ -10,9 +10,8 @@
 
 
 class TrainingTypePlugin(Plugin, ABC):
-    """A Plugin to change the behaviour of the training, validation and test-loop.
+    """A Plugin to change the behaviour of the training, validation and test-loop."""
 
-    """
     def __init__(self):
         self._model = None
         self._results = None
@@ -21,46 +20,39 @@ def __init__(self):
     @property
     @abstractmethod
     def on_gpu(self) -> bool:
-        """Returns whether the current process is done on GPU
-        """
+        """Returns whether the current process is done on GPU"""
         raise NotImplementedError
 
     @property
     @abstractmethod
     def root_device(self) -> torch.device:
-        """Returns the root device
-        """
+        """Returns the root device"""
         raise NotImplementedError
 
     @abstractmethod
     def model_to_device(self):
-        """Moves the model to the correct device
-        """
+        """Moves the model to the correct device"""
         raise NotImplementedError
 
     @property
     @abstractmethod
     def is_global_zero(self) -> bool:
-        """Whether the current process is the rank zero process not only on the local node, but for all nodes.
-        """
+        """Whether the current process is the rank zero process not only on the local node, but for all nodes."""
         raise NotImplementedError
 
     @abstractmethod
     def reduce(self, output, *args, **kwargs):
-        """Reduces the given output (e.g. across GPUs/Processes)
-        """
+        """Reduces the given output (e.g. across GPUs/Processes)"""
         raise NotImplementedError
 
     @abstractmethod
     def barrier(self, name: Optional[str] = None):
-        """Forces all possibly joined processes to wait for each other
-        """
+        """Forces all possibly joined processes to wait for each other"""
         raise NotImplementedError
 
     @abstractmethod
     def broadcast(self, obj: object, src: int = 0) -> object:
-        """Broadcasts an object to all processes
-        """
+        """Broadcasts an object to all processes"""
         raise NotImplementedError
 
     # TODO method this is currently unused. Check after complete refactors are pushed
@@ -72,18 +64,15 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, device_ids):
         os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
         all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
         devices = os.environ.get("CUDA_VISIBLE_DEVICES", all_gpu_ids)
-        log.info(f'LOCAL_RANK: {self.trainer.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]')
+        log.info(f"LOCAL_RANK: {self.trainer.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")
 
     def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
-        """Reduce the early stopping decision across all possibly spawned processes
-        """
+        """Reduce the early stopping decision across all possibly spawned processes"""
         return should_stop
 
     @property
     def model(self) -> torch.nn.Module:
-        """Returns the potentially wrapped LightningModule
-
-        """
+        """Returns the potentially wrapped LightningModule"""
         return self._model
 
     @model.setter
@@ -92,9 +81,7 @@ def model(self, new_model: torch.nn.Module):
 
     @property
     def lightning_module(self) -> LightningModule:
-        """Returns the pure LightningModule without potential wrappers
-
-        """
+        """Returns the pure LightningModule without potential wrappers"""
         return self._model
 
     @property
@@ -110,10 +97,10 @@ def results(self):
     def rpc_enabled(self) -> bool:
         return False
 
-    def start_training(self, trainer: 'Trainer') -> None:
+    def start_training(self, trainer: "Trainer") -> None:
         # double dispatch to initiate the training loop
         self._results = trainer.train()
 
-    def start_testing(self, trainer: 'Trainer') -> None:
+    def start_testing(self, trainer: "Trainer") -> None:
         # double dispatch to initiate the test loop
         self._results = trainer.run_test()
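
Taken together, the abstract members above form the contract a concrete training-type plugin has to satisfy. A minimal single-device sketch, limited to the methods shown in this diff; the SingleDeviceSketch class is illustrative and not part of the commit:

    from typing import Optional

    import torch

    from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin


    class SingleDeviceSketch(TrainingTypePlugin):
        """Illustrative single-device implementation of the abstract interface."""

        def __init__(self, device: torch.device):
            super().__init__()
            self.device = device

        @property
        def on_gpu(self) -> bool:
            return self.device.type == "cuda"

        @property
        def root_device(self) -> torch.device:
            return self.device

        def model_to_device(self):
            # assumes a model has already been connected via plugin.connect(model)
            self._model.to(self.device)

        @property
        def is_global_zero(self) -> bool:
            return True  # only one process, so it is always rank zero

        def reduce(self, output, *args, **kwargs):
            return output  # nothing to reduce across a single device

        def barrier(self, name: Optional[str] = None):
            pass  # no other processes to wait for

        def broadcast(self, obj: object, src: int = 0) -> object:
            return obj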
