Commit f5d09c7

Avoid changing current cudnn benchmark

1 parent e68f2a1 commit f5d09c7
File tree: 5 files changed (+45, -28 lines)

  CHANGELOG.md
  docs/source/common/trainer.rst
  pytorch_lightning/trainer/connectors/accelerator_connector.py
  pytorch_lightning/trainer/trainer.py
  tests/trainer/test_trainer.py

CHANGELOG.md

Lines changed: 4 additions & 1 deletion

@@ -97,7 +97,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `pytorch_lightning.core.lightning` to `pytorch_lightning.core.module` ([#12740](https://github.com/PyTorchLightning/pytorch-lightning/pull/12740))


-
+- Set `torch.backends.cudnn.benchmark=False` by default (unlike in v1.6.{0-4}) after speed and memory problems based on the data used. Please consider tuning `Trainer(benchmark)` manually. ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO))
+
+
+- Prevent modification of `torch.backends.cudnn.benchmark` when `benchmark` is not set on the `Trainer` ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO))

 ### Deprecated

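Taken together, the two changelog entries mean that `Trainer(benchmark=None)` no longer overwrites a value the user set themselves. A minimal sketch of the new behavior (assuming `pytorch_lightning` at this commit; the asserts are illustrative, not part of the library):

    import torch
    from pytorch_lightning import Trainer

    torch.backends.cudnn.benchmark = True  # value chosen earlier in the session

    # benchmark=None (the default) now leaves the session value untouched
    Trainer(benchmark=None)
    assert torch.backends.cudnn.benchmark is True

    # an explicit value still takes effect
    Trainer(benchmark=False)
    assert torch.backends.cudnn.benchmark is False
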
docs/source/common/trainer.rst

Lines changed: 11 additions & 8 deletions

@@ -437,21 +437,24 @@ benchmark

 |

-Defaults to ``True`` if :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is not set.
-This flag sets the ``torch.backends.cudnn.benchmark`` flag. You can read more about its impact
+The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to. The value for
+``torch.backends.cudnn.benchmark`` set in the current session will be used (``False`` if not manually set).
+If :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is set to ``True``, this will default to ``False``.
+You can read more about the interaction of ``torch.backends.cudnn.benchmark`` and ``torch.backends.cudnn.deterministic``
 `here <https://pytorch.org/docs/stable/notes/randomness.html#cuda-convolution-benchmarking>`__

-This is likely to increase the speed of your system if your input sizes don't change. However, if they do, then it
-might make your system slower. The CUDNN auto-tuner will try to find the best algorithm for the hardware when a new
-input size is encountered. Read more about it `here <https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`__.
+Setting this flag to ``True`` can increase the speed of your system if your input sizes don't
+change. However, if they do, then it might make your system slower. The CUDNN auto-tuner will try to find the best
+algorithm for the hardware when a new input size is encountered. This might also increase the memory usage.
+Read more about it `here <https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`__.

 Example::

-    # defaults to True if not deterministic (which is False by default)
-    trainer = Trainer()
+    # Will use whatever the current value for torch.backends.cudnn.benchmark is, normally False
+    trainer = Trainer(benchmark=None)  # default

     # you can overwrite the value
-    trainer = Trainer(benchmark=False)
+    trainer = Trainer(benchmark=True)

 deterministic
 ^^^^^^^^^^^^^

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 16 additions & 11 deletions

@@ -148,14 +148,19 @@ def __init__(
             A. Class > str
             B. Strategy > Accelerator/precision/plugins
         """
-        if benchmark and deterministic:
-            rank_zero_warn(
-                "You passed `deterministic=True` and `benchmark=True`. Note that PyTorch ignores"
-                " torch.backends.cudnn.deterministic=True when torch.backends.cudnn.benchmark=True.",
-            )
-        self.benchmark = not deterministic if benchmark is None else benchmark
+        if deterministic:
+            if benchmark is None:
+                # Set benchmark to False to ensure determinism
+                benchmark = False
+            elif benchmark:
+                rank_zero_warn(
+                    "You passed `deterministic=True` and `benchmark=True`. Note that PyTorch ignores"
+                    " torch.backends.cudnn.deterministic=True when torch.backends.cudnn.benchmark=True.",
+                )
         # TODO: move to gpu accelerator
-        torch.backends.cudnn.benchmark = self.benchmark
+        if benchmark is not None:
+            torch.backends.cudnn.benchmark = benchmark
+        self.benchmark = torch.backends.cudnn.benchmark
         self.replace_sampler_ddp = replace_sampler_ddp
         self._init_deterministic(deterministic)

@@ -215,13 +220,13 @@ def __init__(
         # 6. Instantiate Strategy - Part 2
        self._lazy_init_strategy()

-    def _init_deterministic(self, deterministic: Union[bool, _LITERAL_WARN]) -> None:
-        self.deterministic = deterministic
+    def _init_deterministic(self, deterministic: Optional[Union[bool, _LITERAL_WARN]]) -> None:
+        self.deterministic = deterministic or False  # default to False if not set
         if _TORCH_GREATER_EQUAL_1_11 and deterministic == "warn":
             torch.use_deterministic_algorithms(True, warn_only=True)
         else:
-            torch.use_deterministic_algorithms(deterministic)
-        if deterministic:
+            torch.use_deterministic_algorithms(self.deterministic)
+        if self.deterministic:
             # fixing non-deterministic part of horovod
             # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
             os.environ["HOROVOD_FUSION_THRESHOLD"] = "0"
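
The resolution order is easier to see in isolation. The following sketch mirrors the connector logic above; `resolve_benchmark` is a hypothetical helper written for this example, not a Lightning API:

    from typing import Optional

    import torch

    def resolve_benchmark(benchmark: Optional[bool], deterministic: Optional[bool]) -> bool:
        # deterministic=True with no explicit benchmark forces benchmarking off
        if deterministic and benchmark is None:
            benchmark = False
        # only an explicit (non-None) value touches the global flag
        if benchmark is not None:
            torch.backends.cudnn.benchmark = benchmark
        # the effective value is whatever the session now holds
        return torch.backends.cudnn.benchmark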

pytorch_lightning/trainer/trainer.py

Lines changed: 8 additions & 6 deletions

@@ -172,7 +172,7 @@ def __init__(
         resume_from_checkpoint: Optional[Union[Path, str]] = None,
         profiler: Optional[Union[Profiler, str]] = None,
         benchmark: Optional[bool] = None,
-        deterministic: Union[bool, _LITERAL_WARN] = False,
+        deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
         reload_dataloaders_every_n_epochs: int = 0,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
@@ -224,9 +224,11 @@ def __init__(
                 that only one process at a time can access them.
                 Default: ``False``.

-            benchmark: Sets ``torch.backends.cudnn.benchmark``.
-                Defaults to ``True`` if :paramref:`~pytorch_lightning.trainer.trainer.Trainer.deterministic`
-                is ``False``. Overwrite to manually set a different value. Default: ``None``.
+            benchmark: The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
+                The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
+                (``False`` if not manually set). If :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is set
+                to ``True``, this will default to ``False``. Override to manually set a different value.
+                Default: ``None``.

             callbacks: Add a callback or list of callbacks.
                 Default: ``None``.
@@ -250,8 +252,8 @@ def __init__(

             deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
                 Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
-                that don't support deterministic mode (requires Pytorch 1.11+).
-                Default: ``False``.
+                that don't support deterministic mode (requires Pytorch 1.11+). If not set, defaults to ``False``.
+                Default: ``None``.

             devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
                 based on the accelerator type.
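
In terms of user-facing behavior, the docstring changes translate to roughly the following argument combinations; a usage sketch (assuming this commit, construction only, no training):

    from pytorch_lightning import Trainer

    # both defaults (None): the session's cudnn benchmark value is kept as-is
    Trainer()

    # deterministic=True with benchmark unset: cudnn benchmark is forced to False
    Trainer(deterministic=True)

    # an explicit benchmark=True wins but warns, since PyTorch ignores
    # cudnn.deterministic=True while cudnn.benchmark=True
    Trainer(deterministic=True, benchmark=True)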

tests/trainer/test_trainer.py

Lines changed: 6 additions & 2 deletions

@@ -642,12 +642,15 @@ def test_trainer_max_steps_accumulate_batches(tmpdir):
 @pytest.mark.parametrize(
     ["benchmark_", "deterministic", "expected"],
     [
-        (None, False, True),
+        (None, False, None),
         (None, True, False),
+        (None, None, None),
         (True, False, True),
         (True, True, True),
-        (False, True, False),
+        (True, None, True),
         (False, False, False),
+        (False, True, False),
+        (False, None, False),
     ],
 )
 def test_benchmark_option(benchmark_, deterministic, expected):
@@ -660,6 +663,7 @@ def test_benchmark_option(benchmark_, deterministic, expected):
             trainer = Trainer(benchmark=benchmark_, deterministic=deterministic)
     else:
         trainer = Trainer(benchmark=benchmark_, deterministic=deterministic)
+    expected = original_val if expected is None else expected
     assert torch.backends.cudnn.benchmark == expected
     assert trainer._accelerator_connector.benchmark == expected
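
In the parametrization, `expected=None` means "the flag must be left unchanged". `original_val` itself is outside this hunk; presumably the test captures the session value before building the `Trainer`, along the lines of this sketch:

    original_val = torch.backends.cudnn.benchmark  # saved before Trainer creation
    ...
    expected = original_val if expected is None else expected  # None means unchanged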
