5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -97,7 +97,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `pytorch_lightning.core.lightning` to `pytorch_lightning.core.module` ([#12740](https://github.com/PyTorchLightning/pytorch-lightning/pull/12740))
 
 
--
+- Keep `torch.backends.cudnn.benchmark=False` by default (unlike in v1.6.{0-4}) after speed and memory problems depending on the data used. Please consider tuning `Trainer(benchmark)` manually. ([#13154](https://github.com/PyTorchLightning/pytorch-lightning/pull/13154))
+
+
+- Prevent modification of `torch.backends.cudnn.benchmark` when `Trainer(benchmark=...)` is not set ([#13154](https://github.com/PyTorchLightning/pytorch-lightning/pull/13154))
 
 ### Deprecated
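A minimal sketch of the behavior change described by the two changelog entries above, assuming pytorch-lightning at this commit (a CPU-only install is enough, since `torch.backends.cudnn.benchmark` is just a module-level flag):

```python
import torch
from pytorch_lightning import Trainer

# Before this PR (v1.6.0-4), Trainer() set torch.backends.cudnn.benchmark = True
# unless deterministic=True. After this PR, the session value is left untouched.
print(torch.backends.cudnn.benchmark)  # False unless previously set in this session

trainer = Trainer(benchmark=None)  # the default: does not modify the global flag
print(torch.backends.cudnn.benchmark)  # unchanged

trainer = Trainer(benchmark=True)  # an explicit value still overwrites the flag
print(torch.backends.cudnn.benchmark)  # True
```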
19 changes: 11 additions & 8 deletions docs/source/common/trainer.rst
@@ -437,21 +437,24 @@ benchmark
 
 |
 
-Defaults to ``True`` if :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is not set.
-This flag sets the ``torch.backends.cudnn.benchmark`` flag. You can read more about its impact
+The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to. The value for
+``torch.backends.cudnn.benchmark`` set in the current session will be used (``False`` if not manually set).
+If :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is set to ``True``, this will default to ``False``.
+You can read more about the interaction of ``torch.backends.cudnn.benchmark`` and ``torch.backends.cudnn.deterministic``
 `here <https://pytorch.org/docs/stable/notes/randomness.html#cuda-convolution-benchmarking>`__
 
-This is likely to increase the speed of your system if your input sizes don't change. However, if they do, then it
-might make your system slower. The CUDNN auto-tuner will try to find the best algorithm for the hardware when a new
-input size is encountered. Read more about it `here <https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`__.
+Setting this flag to ``True`` can increase the speed of your system if your input sizes don't
+change. However, if they do, then it might make your system slower. The CUDNN auto-tuner will try to find the best
+algorithm for the hardware when a new input size is encountered. This might also increase the memory usage.
+Read more about it `here <https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`__.
 
 Example::
 
-    # defaults to True if not deterministic (which is False by default)
-    trainer = Trainer()
+    # Will use whatever the current value for torch.backends.cudnn.benchmark, normally False
+    trainer = Trainer(benchmark=None)  # default
 
     # you can overwrite the value
     trainer = Trainer(benchmark=False)
+    trainer = Trainer(benchmark=True)
 
 deterministic
 ^^^^^^^^^^^^^
27 changes: 16 additions & 11 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -148,14 +148,19 @@ def __init__(
             A. Class > str
             B. Strategy > Accelerator/precision/plugins
         """
-        if benchmark and deterministic:
-            rank_zero_warn(
-                "You passed `deterministic=True` and `benchmark=True`. Note that PyTorch ignores"
-                " torch.backends.cudnn.deterministic=True when torch.backends.cudnn.benchmark=True.",
-            )
-        self.benchmark = not deterministic if benchmark is None else benchmark
+        if deterministic:
+            if benchmark is None:
+                # Set benchmark to False to ensure determinism
+                benchmark = False
+            elif benchmark:
+                rank_zero_warn(
+                    "You passed `deterministic=True` and `benchmark=True`. Note that PyTorch ignores"
+                    " torch.backends.cudnn.deterministic=True when torch.backends.cudnn.benchmark=True.",
+                )
         # TODO: move to gpu accelerator
-        torch.backends.cudnn.benchmark = self.benchmark
+        if benchmark is not None:
+            torch.backends.cudnn.benchmark = benchmark
+        self.benchmark = torch.backends.cudnn.benchmark
         self.replace_sampler_ddp = replace_sampler_ddp
         self._init_deterministic(deterministic)
 
@@ -215,13 +220,13 @@ def __init__(
         # 6. Instantiate Strategy - Part 2
         self._lazy_init_strategy()
 
-    def _init_deterministic(self, deterministic: Union[bool, _LITERAL_WARN]) -> None:
-        self.deterministic = deterministic
+    def _init_deterministic(self, deterministic: Optional[Union[bool, _LITERAL_WARN]]) -> None:
+        self.deterministic = deterministic or False  # default to False if not set
         if _TORCH_GREATER_EQUAL_1_11 and deterministic == "warn":
             torch.use_deterministic_algorithms(True, warn_only=True)
         else:
-            torch.use_deterministic_algorithms(deterministic)
-        if deterministic:
+            torch.use_deterministic_algorithms(self.deterministic)
+        if self.deterministic:
             # fixing non-deterministic part of horovod
             # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
             os.environ["HOROVOD_FUSION_THRESHOLD"] = "0"
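The new connector logic above can be summarized as a pure function. This is a sketch for illustration only, not part of the PR: `resolve_benchmark` is a hypothetical name, the warning path is omitted, and the precedence it encodes (an explicit `benchmark` wins, then `deterministic=True` forces `False`, otherwise the session value is kept) mirrors the diff above:

```python
from typing import Optional

import torch


def resolve_benchmark(benchmark: Optional[bool], deterministic: Optional[bool]) -> bool:
    """Hypothetical helper mirroring the connector's new precedence rules."""
    if deterministic and benchmark is None:
        benchmark = False  # determinism requested: disable cudnn benchmarking
    if benchmark is not None:
        torch.backends.cudnn.benchmark = benchmark  # only touch the flag when asked to
    return torch.backends.cudnn.benchmark  # otherwise keep the session value


# Assuming the session default (False); note the calls mutate the global flag in order:
assert resolve_benchmark(None, None) is False  # flag left untouched
assert resolve_benchmark(None, True) is False  # determinism forces False
assert resolve_benchmark(True, None) is True   # explicit value wins
```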
14 changes: 8 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -171,7 +171,7 @@ def __init__(
         resume_from_checkpoint: Optional[Union[Path, str]] = None,
         profiler: Optional[Union[Profiler, str]] = None,
         benchmark: Optional[bool] = None,
-        deterministic: Union[bool, _LITERAL_WARN] = False,
+        deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
         reload_dataloaders_every_n_epochs: int = 0,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
@@ -223,9 +223,11 @@ def __init__(
                 that only one process at a time can access them.
                 Default: ``False``.
 
-            benchmark: Sets ``torch.backends.cudnn.benchmark``.
-                Defaults to ``True`` if :paramref:`~pytorch_lightning.trainer.trainer.Trainer.deterministic`
-                is ``False``. Overwrite to manually set a different value. Default: ``None``.
+            benchmark: The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
+                The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
+                (``False`` if not manually set). If :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is set
+                to ``True``, this will default to ``False``. Override to manually set a different value.
+                Default: ``None``.
 
             callbacks: Add a callback or list of callbacks.
                 Default: ``None``.
@@ -249,8 +251,8 @@ def __init__(
 
             deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
                 Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
-                that don't support deterministic mode (requires Pytorch 1.11+).
-                Default: ``False``.
+                that don't support deterministic mode (requires Pytorch 1.11+). If not set, defaults to ``False``.
+                Default: ``None``.
 
             devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
                 based on the accelerator type.
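A short usage sketch of the two parameters as documented above. This assumes PyTorch 1.11+ for the `"warn"` mode; the exact runtime effect of determinism depends on which ops your model uses:

```python
from pytorch_lightning import Trainer

# deterministic now defaults to None (treated as False); benchmark defaults to
# None, which leaves torch.backends.cudnn.benchmark alone.
trainer = Trainer()

# Force deterministic algorithms; benchmark is implicitly resolved to False.
trainer = Trainer(deterministic=True)

# Use deterministic algorithms where available, but only warn (rather than raise)
# on ops without a deterministic implementation. Requires PyTorch 1.11+.
trainer = Trainer(deterministic="warn")
```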
13 changes: 9 additions & 4 deletions tests/trainer/test_trainer.py
@@ -639,27 +639,32 @@ def test_trainer_max_steps_accumulate_batches(tmpdir):
     assert trainer.global_step == trainer.max_steps, "Model did not stop at max_steps"
 
 
+@pytest.mark.parametrize("cudnn_benchmark", (False, True))
 @pytest.mark.parametrize(
     ["benchmark_", "deterministic", "expected"],
     [
-        (None, False, True),
+        (None, False, None),
         (None, True, False),
+        (None, None, None),
         (True, False, True),
         (True, True, True),
-        (False, True, False),
+        (True, None, True),
         (False, False, False),
+        (False, True, False),
+        (False, None, False),
     ],
 )
-def test_benchmark_option(benchmark_, deterministic, expected):
+def test_benchmark_option(cudnn_benchmark, benchmark_, deterministic, expected):
     """Verify benchmark option."""
 
     original_val = torch.backends.cudnn.benchmark
 
+    torch.backends.cudnn.benchmark = cudnn_benchmark
     if benchmark_ and deterministic:
         with pytest.warns(UserWarning, match="You passed `deterministic=True` and `benchmark=True`"):
             trainer = Trainer(benchmark=benchmark_, deterministic=deterministic)
     else:
         trainer = Trainer(benchmark=benchmark_, deterministic=deterministic)
+    expected = cudnn_benchmark if expected is None else expected
     assert torch.backends.cudnn.benchmark == expected
     assert trainer._accelerator_connector.benchmark == expected
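The `expected=None` rows in the parametrization above are a sentinel: they assert that the seeded session value survives `Trainer` construction untouched. A plain-Python illustration of how the test resolves the sentinel (no Lightning needed):

```python
# Row (benchmark_=None, deterministic=None, expected=None), seed cudnn_benchmark=True:
#   the Trainer must leave the flag alone, so expected becomes the seed value.
# Row (benchmark_=None, deterministic=True, expected=False), either seed:
#   determinism forces benchmarking off, so the flag must read False afterwards.

cudnn_benchmark = True  # the seeded session value
expected = None         # sentinel: "expect the seeded value back"
expected = cudnn_benchmark if expected is None else expected
assert expected is True
```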