Commit d8b0bf5

Code cleaning in preparation for 7258
1 parent 7a48db5 commit d8b0bf5

12 files changed: 427 additions & 436 deletions

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 18 additions & 12 deletions
@@ -11,39 +11,38 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 
 
-class ConfigValidator(object):
+class ConfigValidator:
 
-    def __init__(self, trainer):
+    def __init__(self, trainer: 'pl.Trainer') -> None:
         self.trainer = trainer
 
-    def verify_loop_configurations(self, model: LightningModule) -> None:
+    def verify_loop_configurations(self, model: 'pl.LightningModule') -> None:
         r"""
         Checks that the model is configured correctly before the run is started.
 
         Args:
             model: The model to check the configuration.
 
         """
-        if self.trainer.state == TrainerState.FITTING:
+        if self.trainer.state in (TrainerState.FITTING, TrainerState.TUNING):
             self.__verify_train_loop_configuration(model)
             self.__verify_eval_loop_configuration(model, 'val')
-        elif self.trainer.state == TrainerState.TUNING:
-            self.__verify_train_loop_configuration(model)
         elif self.trainer.state == TrainerState.VALIDATING:
             self.__verify_eval_loop_configuration(model, 'val')
         elif self.trainer.state == TrainerState.TESTING:
             self.__verify_eval_loop_configuration(model, 'test')
         elif self.trainer.state == TrainerState.PREDICTING:
             self.__verify_predict_loop_configuration(model)
+        self.__verify_dp_batch_transfer_support(model)
 
-    def __verify_train_loop_configuration(self, model):
+    def __verify_train_loop_configuration(self, model: 'pl.LightningModule') -> None:
         # -----------------------------------
         # verify model has a training step
         # -----------------------------------
@@ -82,14 +81,14 @@ def __verify_train_loop_configuration(self, model):
         going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()
 
         has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
-        if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization:
+        if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
             raise MisconfigurationException(
                 'When overriding `LightningModule` optimizer_step or optimizer_zero_grad,'
                 ' `accumulate_grad_batches` in `Trainer` should be 1.'
                 ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
             )
 
-    def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None:
+    def __verify_eval_loop_configuration(self, model: 'pl.LightningModule', stage: str) -> None:
         loader_name = f'{stage}_dataloader'
         step_name = 'validation_step' if stage == 'val' else 'test_step'
 
@@ -101,8 +100,15 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None:
         if has_step and not has_loader:
             rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop')
 
-    def __verify_predict_loop_configuration(self, model: LightningModule) -> None:
-
+    def __verify_predict_loop_configuration(self, model: 'pl.LightningModule') -> None:
         has_predict_dataloader = is_overridden('predict_dataloader', model)
         if not has_predict_dataloader:
             raise MisconfigurationException('Dataloader not found for `Trainer.predict`')
+
+    def __verify_dp_batch_transfer_support(self, model: 'pl.LightningModule') -> None:
+        """Raise Misconfiguration exception since these hooks are not supported in DP mode"""
+        # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
+        batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
+        for hook in batch_transfer_hooks:
+            if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
+                raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')
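Note: the `__verify_dp_batch_transfer_support` check added here (moved from the data connector, see the next file) boils down to refusing to run under DP when any batch-transfer hook is overridden. A minimal, self-contained sketch of that `is_overridden`-style check follows; `BaseModule`, `ToyModule` and the plain `use_dp` flag are illustrative stand-ins, not Lightning's real `LightningModule` or accelerator connector.

    class MisconfigurationException(Exception):
        """Stand-in for Lightning's exception of the same name."""


    class BaseModule:
        # default (not overridden) batch-transfer hooks
        def on_before_batch_transfer(self, batch, dataloader_idx):
            return batch

        def transfer_batch_to_device(self, batch, device):
            return batch

        def on_after_batch_transfer(self, batch, dataloader_idx):
            return batch


    class ToyModule(BaseModule):
        def transfer_batch_to_device(self, batch, device):
            # user-defined override: the check below should flag this under DP
            return batch


    def is_overridden(method_name: str, instance: BaseModule) -> bool:
        # an override exists if the class attribute differs from the base implementation
        return getattr(type(instance), method_name) is not getattr(BaseModule, method_name)


    def verify_dp_batch_transfer_support(model: BaseModule, use_dp: bool) -> None:
        batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
        for hook in batch_transfer_hooks:
            if use_dp and is_overridden(hook, model):
                raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')


    verify_dp_batch_transfer_support(ToyModule(), use_dp=False)   # passes silently
    # verify_dp_batch_transfer_support(ToyModule(), use_dp=True)  # would raise

The real check differs only in that it reads the DP flag from `trainer.accelerator_connector` and compares against `LightningModule`'s own hook implementations.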

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 6 additions & 12 deletions
@@ -16,6 +16,7 @@
 
 from torch.utils.data import DataLoader
 
+import pytorch_lightning as pl
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -89,7 +90,6 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
         # set up the passed in dataloaders (if needed)
         self.attach_dataloaders(model, train_dataloader, val_dataloaders)
         self.attach_datamodule(model, datamodule)
-        self._validate_data_hooks(model)
 
     def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
@@ -98,22 +98,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
                 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
             )
 
-    def _validate_data_hooks(self, model):
-        # Raise Misconfiguration exception since these hooks are not supported in DP mode
-        # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
-        batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
-        for hook in batch_transfer_hooks:
-            if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
-                raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')
-
     def attach_dataloaders(
         self,
-        model,
+        model: 'pl.LightningModule',
         train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
         val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
-    ):
+    ) -> None:
         # when dataloader is passed via fit, patch the train_dataloader
         # functions to overwrite with these implementations
         if train_dataloader is not None:
@@ -128,7 +120,9 @@ def attach_dataloaders(
         if predict_dataloaders is not None:
             model.predict_dataloader = _PatchDataLoader(predict_dataloaders)
 
-    def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None:
+    def attach_datamodule(
+        self, model: 'pl.LightningModule', datamodule: Optional['pl.LightningDataModule'] = None
+    ) -> None:
         # We use datamodule if it's been provided, otherwise we check model for it
         datamodule = datamodule or getattr(model, 'datamodule', None)
 

pytorch_lightning/trainer/connectors/model_connector.py

Lines changed: 0 additions & 5 deletions
@@ -11,11 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Root module for all distributed operations in Lightning.
-Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
-
-"""
 from weakref import proxy
 
 
pytorch_lightning/trainer/predict_loop.py

Lines changed: 1 addition & 5 deletions
@@ -76,11 +76,7 @@ def on_predict_model_eval(self):
         model_ref = self.trainer.lightning_module
         model_ref.on_predict_model_eval()
 
-    def setup(self, model, max_batches, dataloaders):
-
-        # copy properties for forward overrides
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
+    def setup(self, max_batches, dataloaders):
         # convert max_batches to list
         if isinstance(max_batches, int):
             max_batches = [max_batches] * len(dataloaders)
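After this change, `setup` no longer copies model properties and only normalizes `max_batches`: a single integer limit is broadcast to one entry per predict dataloader. A small standalone sketch of that normalization (the helper name is illustrative, not the loop's API):

    from typing import List, Sequence, Union


    def normalize_max_batches(max_batches: Union[int, List[int]], dataloaders: Sequence) -> List[int]:
        # a single int limit applies to every predict dataloader
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)
        return max_batches


    assert normalize_max_batches(5, dataloaders=["dl_a", "dl_b"]) == [5, 5]
    assert normalize_max_batches([3, 7], dataloaders=["dl_a", "dl_b"]) == [3, 7]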

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 3 deletions
@@ -775,7 +775,7 @@ def run_predict(self) -> Optional[_PREDICT_OUTPUT]:
             return []
 
         # set up the eval loop
-        self.predict_loop.setup(self.lightning_module, max_batches, dataloaders)
+        self.predict_loop.setup(max_batches, dataloaders)
 
         # call hook
         self.predict_loop.on_predict_start()
@@ -1086,8 +1086,6 @@ def tune(
         Runs routines to tune hyperparameters before training.
 
         Args:
-            datamodule: A instance of :class:`LightningDataModule`.
-
             model: Model to tune.
 
             train_dataloader: A Pytorch DataLoader with training samples. If the model has
@@ -1096,6 +1094,7 @@ def tune(
             val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                 If the model has a predefined val_dataloaders method this will be skipped
 
+            datamodule: A instance of :class:`LightningDataModule`.
         """
         Trainer._log_api_event("tune")
         self.state = TrainerState.TUNING

pytorch_lightning/tuner/auto_gpu_select.py

Lines changed: 4 additions & 4 deletions
@@ -17,11 +17,11 @@
 
 
 def pick_multiple_gpus(nb):
-    '''
+    """
     Raises:
        MisconfigurationException:
            If ``gpus`` is set to 0, when ``auto_select_gpus=True``.
-    '''
+    """
     if nb == 0:
         raise MisconfigurationException(
             r"auto_select_gpus=True, gpus=0 is not a valid configuration.\
@@ -38,11 +38,11 @@ def pick_multiple_gpus(nb):
 
 
 def pick_single_gpu(exclude_gpus: list):
-    '''
+    """
     Raises:
        RuntimeError:
            If you try to allocate a GPU, when no GPUs are available.
-    '''
+    """
    for i in range(torch.cuda.device_count()):
        if i in exclude_gpus:
            continue
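The docstrings now spell out which exceptions these helpers raise. For illustration, here is a hedged sketch of the selection logic they describe, with an injected `is_free` check so it runs without CUDA; `pick_single` and `pick_multiple` are stand-ins, not the library functions (which probe real GPUs and raise `MisconfigurationException`/`RuntimeError`):

    from typing import Callable, List


    def pick_single(exclude_gpus: List[int], device_count: int, is_free: Callable[[int], bool]) -> int:
        """Return the first usable GPU index not in `exclude_gpus`; raise if none is available."""
        for i in range(device_count):
            if i in exclude_gpus or not is_free(i):
                continue
            return i
        raise RuntimeError("No GPUs available.")


    def pick_multiple(nb: int, device_count: int, is_free: Callable[[int], bool]) -> List[int]:
        if nb == 0:
            raise ValueError("auto_select_gpus=True, gpus=0 is not a valid configuration.")
        picked: List[int] = []
        for _ in range(nb):
            # already-picked indices are excluded from the next search
            picked.append(pick_single(picked, device_count, is_free))
        return picked


    # pretend GPUs 0 and 2 are busy; 4 devices in total
    assert pick_multiple(2, device_count=4, is_free=lambda i: i not in (0, 2)) == [1, 3]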

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 20 additions & 14 deletions
@@ -15,7 +15,7 @@
 import os
 from typing import Optional, Tuple
 
-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -28,21 +28,22 @@
 
 
 def scale_batch_size(
-    trainer,
-    model: LightningModule,
+    trainer: 'pl.Trainer',
+    model: 'pl.LightningModule',
     mode: str = 'power',
     steps_per_trial: int = 3,
     init_val: int = 2,
     max_trials: int = 25,
     batch_arg_name: str = 'batch_size',
     **fit_kwargs
-):
+) -> Optional[int]:
     r"""
     Will iteratively try to find the largest batch size for a given model
     that does not give an out of memory (OOM) error.
 
     Args:
         trainer: The Trainer
+
         model: Model to fit.
 
         mode: string setting the search mode. Either `power` or `binsearch`.
@@ -53,7 +54,7 @@ def scale_batch_size(
             batch size that failed.
 
         steps_per_trial: number of steps to run with a given batch size.
-            Idealy 1 should be enough to test if a OOM error occurs,
+            Ideally 1 should be enough to test if a OOM error occurs,
             however in practise a few are needed
 
         init_val: initial batch size to start the search with
@@ -113,7 +114,7 @@ def scale_batch_size(
     trainer.progress_bar_callback.disable()
 
     # Initially we just double in size until an OOM is encountered
-    new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
+    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
     if mode == 'power':
         new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
     elif mode == 'binsearch':
@@ -139,7 +140,7 @@ def scale_batch_size(
     return new_size
 
 
-def __scale_batch_dump_params(trainer):
+def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None:
     # Prevent going into infinite loop
     trainer.__dumped_params = {
         'auto_lr_find': trainer.auto_lr_find,
@@ -155,7 +156,7 @@ def __scale_batch_dump_params(trainer):
     }
 
 
-def __scale_batch_reset_params(trainer, model, steps_per_trial):
+def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None:
     trainer.auto_scale_batch_size = None  # prevent recursion
     trainer.auto_lr_find = False  # avoid lr find being called multiple times
     trainer.current_epoch = 0
@@ -168,7 +169,7 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):
     trainer.model = model  # required for saving
 
 
-def __scale_batch_restore_params(trainer):
+def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None:
     trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
     trainer.current_epoch = trainer.__dumped_params['current_epoch']
     trainer.max_steps = trainer.__dumped_params['max_steps']
@@ -181,9 +182,11 @@ def __scale_batch_restore_params(trainer):
     del trainer.__dumped_params
 
 
-def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
-    """ Batch scaling mode where the size is doubled at each iteration until an
-    OOM error is encountered. """
+def _run_power_scaling(
+    trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int,
+    **fit_kwargs
+) -> int:
+    """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """
     for _ in range(max_trials):
         garbage_collection_cuda()
         trainer.global_step = 0  # reset after each try
@@ -207,7 +210,10 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
     return new_size
 
 
-def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
+def _run_binsearch_scaling(
+    trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int,
+    **fit_kwargs
+) -> int:
     """ Batch scaling mode where the size is initially is doubled at each iteration
     until an OOM error is encountered. Hereafter, the batch size is further
     refined using a binary search """
@@ -252,7 +258,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
 
 
 def _adjust_batch_size(
-    trainer,
+    trainer: 'pl.Trainer',
     batch_arg_name: str = 'batch_size',
     factor: float = 1.0,
     value: Optional[int] = None,
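As the call site above shows, `_adjust_batch_size` now returns a `(new_size, changed)` tuple that the scaling loops unpack. A hedged, self-contained sketch of the doubling-until-OOM strategy behind power scaling, with a simulated memory limit instead of real CUDA error handling (all names here are illustrative, not the library's internals):

    from typing import Callable, Tuple


    def adjust_batch_size(current: int, factor: float = 2.0, max_allowed: int = 1 << 16) -> Tuple[int, bool]:
        # mirror the (new_size, changed) return shape: cap at max_allowed and report whether we moved
        new_size = min(int(current * factor), max_allowed)
        return new_size, new_size != current


    def run_power_scaling(init_val: int, max_trials: int, fits_in_memory: Callable[[int], bool]) -> int:
        """Double the batch size until a (simulated) OOM, then back off once."""
        new_size = init_val
        for _ in range(max_trials):
            if not fits_in_memory(new_size):
                # on OOM, halve and stop searching
                new_size, _ = adjust_batch_size(new_size, factor=0.5)
                break
            new_size, changed = adjust_batch_size(new_size, factor=2.0)
            if not changed:
                break
        return new_size


    # pretend anything above 96 samples per batch no longer fits in memory
    assert run_power_scaling(init_val=2, max_trials=25, fits_in_memory=lambda bs: bs <= 96) == 64

The real `_run_power_scaling` additionally runs a short fit at every candidate size and restores trainer state between trials; the sketch only captures the search strategy.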
