Skip to content

Commit e15a664

Browse files
four4fish, ananthsub, and justusschock
authored
Add back deterministic support in accelerator_connector (#11999)
Co-authored-by: ananthsub <[email protected]> Co-authored-by: Justus Schock <[email protected]>
1 parent 6bc0e1d commit e15a664

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
299299

300300
- Changed default logger name to `lightning_logs` for consistency ([#11762](https://github.com/PyTorchLightning/pytorch-lightning/pull/11762))
301301

302+
303+
- Rewrote `accelerator_connector` ([#11448](https://github.com/PyTorchLightning/pytorch-lightning/pull/11448))
304+
302305
### Deprecated
303306

304307
- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,12 @@
7474
rank_zero_warn,
7575
)
7676
from pytorch_lightning.utilities.exceptions import MisconfigurationException
77-
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
77+
from pytorch_lightning.utilities.imports import (
78+
_HOROVOD_AVAILABLE,
79+
_IPU_AVAILABLE,
80+
_TORCH_GREATER_EQUAL_1_8,
81+
_TPU_AVAILABLE,
82+
)
7883

7984
log = logging.getLogger(__name__)
8085

@@ -141,6 +146,7 @@ def __init__(
141146
torch.backends.cudnn.benchmark = benchmark
142147
self.replace_sampler_ddp = replace_sampler_ddp
143148
self.sync_batchnorm = sync_batchnorm
149+
self._init_deterministic(deterministic)
144150

145151
# 1. Parsing flags
146152
# Get registered strategies, built-in accelerators and precision plugins
@@ -196,6 +202,20 @@ def __init__(
196202
# 6. Instantiate Strategy - Part 2
197203
self._lazy_init_strategy()
198204

205+
def _init_deterministic(self, deterministic: bool) -> None:
206+
self.deterministic = deterministic
207+
if _TORCH_GREATER_EQUAL_1_8:
208+
torch.use_deterministic_algorithms(deterministic)
209+
else:
210+
torch.set_deterministic(deterministic)
211+
if deterministic:
212+
# fixing non-deterministic part of horovod
213+
# https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
214+
os.environ["HOROVOD_FUSION_THRESHOLD"] = "0"
215+
216+
# https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
217+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
218+
199219
def _check_config_and_set_final_flags(
200220
self,
201221
strategy: Optional[Union[str, Strategy]],

tests/accelerators/test_accelerator_connector.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,3 +947,12 @@ def test_passing_zero_and_empty_list_to_devices_flag():
947947

948948
with pytest.warns(UserWarning, match=r"switching to `cpu` accelerator"):
949949
Trainer(accelerator="gpu", devices=[])
950+
951+
952+
@pytest.mark.parametrize("deterministic", [True, False])
953+
def test_deterministic_init(deterministic):
954+
trainer = Trainer(accelerator="auto", deterministic=deterministic)
955+
assert trainer._accelerator_connector.deterministic == deterministic
956+
if deterministic:
957+
assert os.environ.get("CUBLAS_WORKSPACE_CONFIG") == ":4096:8"
958+
assert os.environ.get("HOROVOD_FUSION_THRESHOLD") == "0"

0 commit comments

Comments (0)