
Commit 0d3325e

Add support for torch.use_deterministic_algorithms (#9121)
* re-add changes
* Update test_data_parallel.py
* Update CHANGELOG.md
* Update test_legacy_checkpoints.py
* Update test_horovod.py
* Update test_horovod.py
* Update accelerator_connector.py
* update tests
1 parent fb81e73 commit 0d3325e

9 files changed: 60 additions & 43 deletions

CHANGELOG.md

Lines changed: 6 additions & 3 deletions
@@ -163,6 +163,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `pl_legacy_patch` load utility for loading old checkpoints that have pickled legacy Lightning attributes ([#9166](https://github.com/PyTorchLightning/pytorch-lightning/pull/9166))


+- Added support for `torch.use_deterministic_algorithms` ([#9121](https://github.com/PyTorchLightning/pytorch-lightning/pull/9121))
+
+
 ### Changed

 - `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
@@ -225,9 +228,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Executing the `optimizer_closure` is now required when overriding the `optimizer_step` hook ([#9360](https://github.com/PyTorchLightning/pytorch-lightning/pull/9360))


-- Removed `TrainerProperties` mixin and moved property definitions directly into `Trainer` ([#9495](https://github.com/PyTorchLightning/pytorch-lightning/pull/9495))
-
-
 - Changed logging of `LightningModule` and `LightningDataModule` hyperparameters to raise an exception only if there are colliding keys with different values ([#9496](https://github.com/PyTorchLightning/pytorch-lightning/pull/9496))


@@ -394,6 +394,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `call_configure_sharded_model_hook` property from `Accelerator` and `TrainingTypePlugin` ([#9612](https://github.com/PyTorchLightning/pytorch-lightning/pull/9612))


+- Removed `TrainerProperties` mixin and moved property definitions directly into `Trainer` ([#9495](https://github.com/PyTorchLightning/pytorch-lightning/pull/9495))
+
+
 ### Fixed

benchmarks/test_basic_parity.py

Lines changed: 1 addition & 1 deletion
@@ -151,6 +151,7 @@ def vanilla_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):

 def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
     seed_everything(idx)
+    torch.backends.cudnn.deterministic = True

     model = cls_model()
     # init model parts
@@ -161,7 +162,6 @@ def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
         weights_summary=None,
         gpus=1 if device_type == "cuda" else 0,
         checkpoint_callback=False,
-        deterministic=True,
         logger=False,
         replace_sampler_ddp=False,
     )
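For reference, the benchmark now pins determinism at the cuDNN level instead of going through the removed `deterministic=True` Trainer flag. A minimal sketch of those cuDNN globals (standard PyTorch API, not part of this commit):

```python
import torch

# Always select deterministic cuDNN kernels (what the benchmark sets above).
torch.backends.cudnn.deterministic = True
# Optionally disable cuDNN autotuning, whose kernel selection varies with input shapes.
torch.backends.cudnn.benchmark = False
```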

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 22 additions & 11 deletions
@@ -60,10 +60,6 @@
     TorchElasticEnvironment,
 )
 from pytorch_lightning.utilities import (
-    _APEX_AVAILABLE,
-    _HOROVOD_AVAILABLE,
-    _IPU_AVAILABLE,
-    _TPU_AVAILABLE,
     AMPType,
     device_parser,
     DeviceType,
@@ -74,6 +70,14 @@
 )
 from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import (
+    _APEX_AVAILABLE,
+    _HOROVOD_AVAILABLE,
+    _IPU_AVAILABLE,
+    _TORCH_GREATER_EQUAL_1_7,
+    _TORCH_GREATER_EQUAL_1_8,
+    _TPU_AVAILABLE,
+)

 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
@@ -96,7 +100,7 @@ def __init__(
         sync_batchnorm,
         benchmark,
         replace_sampler_ddp,
-        deterministic,
+        deterministic: bool,
         precision,
         amp_type,
         amp_level,
@@ -113,6 +117,7 @@ def __init__(
                 f" Use `Trainer(accelerator={distributed_backend})` instead."
             )
         distributed_backend = distributed_backend or accelerator
+        self._init_deterministic(deterministic)

         self.num_processes = num_processes
         self.devices = devices
@@ -126,7 +131,6 @@ def __init__(
         self.sync_batchnorm = sync_batchnorm
         self.benchmark = benchmark
         self.replace_sampler_ddp = replace_sampler_ddp
-        self.deterministic = deterministic
         self.precision = precision
         self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
         self.amp_level = amp_level
@@ -177,15 +181,22 @@ def __init__(
         # TODO: should this be moved to GPU accelerator?
         torch.backends.cudnn.benchmark = self.benchmark

-        # determinism for cudnn
-        # TODO: should this be moved to GPU accelerator?
-        torch.backends.cudnn.deterministic = deterministic
+        self.replace_sampler_ddp = replace_sampler_ddp
+
+    def _init_deterministic(self, deterministic: bool) -> None:
+        self.deterministic = deterministic
+        if _TORCH_GREATER_EQUAL_1_8:
+            torch.use_deterministic_algorithms(deterministic)
+        elif _TORCH_GREATER_EQUAL_1_7:
+            torch.set_deterministic(deterministic)
+        else:  # the minimum version Lightning supports is PyTorch 1.6
+            torch._set_deterministic(deterministic)
         if deterministic:
             # fixing non-deterministic part of horovod
             # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
             os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)
-
-        self.replace_sampler_ddp = replace_sampler_ddp
+            # https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

     def select_accelerator_type(self) -> None:
         if self.distributed_backend == "auto":
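The heart of the change is `_init_deterministic`, which dispatches to whichever determinism API the running PyTorch provides and sets the environment variables the deterministic kernels need. A standalone sketch of the same dispatch, with a simplified version check standing in for Lightning's `_TORCH_GREATER_EQUAL_*` flags:

```python
import os

import torch

# Simplified stand-in for Lightning's _TORCH_GREATER_EQUAL_1_7 / _1_8 constants.
_TORCH_VERSION = tuple(int(v) for v in torch.__version__.split(".")[:2])


def init_deterministic(deterministic: bool) -> None:
    """Toggle global determinism using whichever API this PyTorch exposes."""
    if _TORCH_VERSION >= (1, 8):
        torch.use_deterministic_algorithms(deterministic)  # public API since 1.8
    elif _TORCH_VERSION >= (1, 7):
        torch.set_deterministic(deterministic)  # the 1.7 spelling of the same switch
    else:  # the minimum version Lightning supports is PyTorch 1.6
        torch._set_deterministic(deterministic)
    if deterministic:
        # Horovod tensor fusion is non-deterministic, so disable it.
        os.environ["HOROVOD_FUSION_THRESHOLD"] = "0"
        # Required for deterministic cuBLAS kernels on CUDA >= 10.2:
        # https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
```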

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 1 deletion
@@ -222,7 +222,8 @@ def __init__(
                 Default: ``os.getcwd()``.
                 Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

-            deterministic: If true enables cudnn.deterministic.
+            deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+                Default: ``False``.

             devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
                 based on the accelerator type.
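Under the new semantics the flag governs all PyTorch operations, not just cuDNN. A hedged usage sketch (`model` and `dm` are placeholders, not defined here):

```python
from pytorch_lightning import Trainer, seed_everything

seed_everything(42)  # fix the Python, NumPy and torch RNGs
trainer = Trainer(deterministic=True)  # request deterministic algorithms globally
# trainer.fit(model, datamodule=dm)
# Note: ops that lack a deterministic implementation now raise a RuntimeError.
```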

tests/accelerators/test_common.py

Lines changed: 5 additions & 12 deletions
@@ -16,6 +16,7 @@

 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.seed import seed_everything
 from tests.accelerators.test_dp import CustomClassificationModelDP
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
@@ -32,28 +33,20 @@
 )
 def test_evaluate(tmpdir, trainer_kwargs):
     tutils.set_random_master_port()
-
+    seed_everything(1)
     dm = ClassifDataModule()
     model = CustomClassificationModelDP()
     trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=2,
-        limit_train_batches=10,
-        limit_val_batches=10,
-        deterministic=True,
-        **trainer_kwargs
+        default_root_dir=tmpdir, max_epochs=2, limit_train_batches=10, limit_val_batches=10, **trainer_kwargs
     )

     trainer.fit(model, datamodule=dm)
     assert "ckpt" in trainer.checkpoint_callback.best_model_path

     old_weights = model.layer_0.weight.clone().detach().cpu()

-    result = trainer.validate(datamodule=dm)
-    assert result[0]["val_acc"] > 0.55
-
-    result = trainer.test(datamodule=dm)
-    assert result[0]["test_acc"] > 0.55
+    trainer.validate(datamodule=dm)
+    trainer.test(datamodule=dm)

     # make sure weights didn't change
     new_weights = model.layer_0.weight.clone().detach().cpu()
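The test now leans on `seed_everything(1)` rather than `deterministic=True`: it only needs reproducible weight initialization, not deterministic kernels. A self-contained illustration of why seeding suffices (not from the test suite):

```python
import torch

from pytorch_lightning import seed_everything

seed_everything(1)
first = torch.nn.Linear(32, 2).weight.clone()
seed_everything(1)
second = torch.nn.Linear(32, 2).weight.clone()
assert torch.equal(first, second)  # same seed -> identical initial weights
```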

tests/checkpointing/test_legacy_checkpoints.py

Lines changed: 1 addition & 1 deletion
@@ -83,9 +83,9 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
         callbacks=[es, stop],
         max_epochs=21,
         accumulate_grad_batches=2,
-        deterministic=True,
         resume_from_checkpoint=path_ckpt,
     )
+    torch.backends.cudnn.deterministic = True
     trainer.fit(model, datamodule=dm)
     res = trainer.test(model, datamodule=dm)
     assert res[0]["test_loss"] <= 0.7

tests/conftest.py

Lines changed: 14 additions & 0 deletions
@@ -22,6 +22,7 @@
 import torch.distributed

 from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
 from tests import _PATH_DATASETS


@@ -52,6 +53,7 @@ def restore_env_variables():
     os.environ.update(env_backup)
     # these are currently known leakers - ideally these would not be allowed
     allowlist = {
+        "CUBLAS_WORKSPACE_CONFIG",  # enabled with deterministic flag
         "CUDA_DEVICE_ORDER",
         "LOCAL_RANK",
         "NODE_RANK",
@@ -87,6 +89,18 @@ def teardown_process_group():
     torch.distributed.destroy_process_group()


+@pytest.fixture(scope="function", autouse=True)
+def reset_deterministic_algorithm():
+    """Ensures that torch determinism settings are reset before the next test runs."""
+    yield
+    if _TORCH_GREATER_EQUAL_1_8:
+        torch.use_deterministic_algorithms(False)
+    elif _TORCH_GREATER_EQUAL_1_7:
+        torch.set_deterministic(False)
+    else:  # the minimum version Lightning supports is PyTorch 1.6
+        torch._set_deterministic(False)
+
+
 @pytest.fixture
 def tmpdir_server(tmpdir):
     if sys.version_info >= (3, 7):
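Because `torch.use_deterministic_algorithms` mutates process-global state, the autouse fixture above resets it after every test. A minimal illustration of the pattern (hypothetical test, assuming PyTorch >= 1.8):

```python
import pytest
import torch


@pytest.fixture(autouse=True)
def _reset_determinism():
    yield  # run the test body first
    torch.use_deterministic_algorithms(False)  # then undo any global toggle


def test_toggle_is_isolated():
    torch.use_deterministic_algorithms(True)
    assert torch.are_deterministic_algorithms_enabled()
    # the fixture resets the flag, so later tests start from the default
```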

tests/models/test_horovod.py

Lines changed: 0 additions & 10 deletions
@@ -78,7 +78,6 @@ def test_horovod_cpu(tmpdir):
         limit_train_batches=0.4,
         limit_val_batches=0.2,
         accelerator="horovod",
-        deterministic=True,
     )
     _run_horovod(trainer_options)

@@ -96,7 +95,6 @@ def test_horovod_cpu_clip_grad_by_value(tmpdir):
         limit_train_batches=0.4,
         limit_val_batches=0.2,
         accelerator="horovod",
-        deterministic=True,
     )
     _run_horovod(trainer_options)

@@ -112,7 +110,6 @@ def test_horovod_cpu_implicit(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        deterministic=True,
     )
     _run_horovod(trainer_options)

@@ -129,7 +126,6 @@ def test_horovod_multi_gpu(tmpdir):
         limit_train_batches=0.4,
         limit_val_batches=0.2,
         gpus=2,
-        deterministic=True,
         accelerator="horovod",
     )
     _run_horovod(trainer_options, on_gpu=True)
@@ -148,7 +144,6 @@ def test_horovod_multi_gpu_grad_by_value(tmpdir):
         limit_train_batches=0.4,
         limit_val_batches=0.2,
         gpus=2,
-        deterministic=True,
         accelerator="horovod",
     )
     _run_horovod(trainer_options, on_gpu=True)
@@ -170,7 +165,6 @@ def test_horovod_apex(tmpdir):
         limit_train_batches=0.4,
         limit_val_batches=0.2,
         gpus=2,
-        deterministic=True,
         accelerator="horovod",
         amp_backend="apex",
         precision=16,
@@ -190,7 +184,6 @@ def test_horovod_amp(tmpdir):
         limit_train_batches=0.4,
         limit_val_batches=0.2,
         gpus=2,
-        deterministic=True,
         accelerator="horovod",
         amp_backend="native",
         precision=16,
@@ -210,7 +203,6 @@ def test_horovod_gather(tmpdir):
         limit_train_batches=0.4,
         limit_val_batches=0.2,
         gpus=2,
-        deterministic=True,
         accelerator="horovod",
     )
     _run_horovod(trainer_options, on_gpu=True)
@@ -236,7 +228,6 @@ def validation_step(self, batch, *args, **kwargs):
         limit_train_batches=0.4,
         limit_val_batches=0.2,
         gpus=1,
-        deterministic=True,
         accelerator="horovod",
     )
     tpipes.run_model_test_without_loggers(trainer_options, model)
@@ -253,7 +244,6 @@ def test_horovod_multi_optimizer(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        deterministic=True,
         accelerator="horovod",
     )
     trainer.fit(model)

tests/overrides/test_data_parallel.py

Lines changed: 9 additions & 4 deletions
@@ -85,8 +85,11 @@ def training_step(self, batch, batch_idx):
             return {"loss": loss}

     model = TestModel()
-    model.trainer = Mock()
-    model.trainer.state.stage = RunningStage.TRAINING
+    trainer = MagicMock()
+    trainer.state.stage = RunningStage.TRAINING
+    trainer.accelerator_connector._init_deterministic(False)
+
+    model.trainer = trainer
     batch = torch.rand(2, 32).cuda()
     batch_idx = 0

@@ -123,8 +126,10 @@ def training_step(self, batch, batch_idx):
             return output

     model = TestModel().to(device)
-    model.trainer = Mock()
-    model.trainer.state.stage = RunningStage.TRAINING
+    trainer = MagicMock()
+    trainer.state.stage = RunningStage.TRAINING
+    trainer.accelerator_connector._init_deterministic(False)
+    model.trainer = trainer
     batch = torch.rand(2, 32).to(device)
     batch_idx = 0

