Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `pl_legacy_patch` load utility for loading old checkpoints that have pickled legacy Lightning attributes ([#9166](https://github.com/PyTorchLightning/pytorch-lightning/pull/9166))


- Added support for `torch.use_deterministic_algorithms` ([#9121](https://github.com/PyTorchLightning/pytorch-lightning/pull/9121))


### Changed

- `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
Expand Down Expand Up @@ -225,9 +228,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Executing the `optimizer_closure` is now required when overriding the `optimizer_step` hook ([#9360](https://github.com/PyTorchLightning/pytorch-lightning/pull/9360))


- Removed `TrainerProperties` mixin and moved property definitions directly into `Trainer` ([#9495](https://github.com/PyTorchLightning/pytorch-lightning/pull/9495))


- Changed logging of `LightningModule` and `LightningDataModule` hyperparameters to raise an exception only if there are colliding keys with different values ([#9496](https://github.com/PyTorchLightning/pytorch-lightning/pull/9496))


Expand Down Expand Up @@ -394,6 +394,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `call_configure_sharded_model_hook` property from `Accelerator` and `TrainingTypePlugin` ([#9612](https://github.com/PyTorchLightning/pytorch-lightning/pull/9612))


- Removed `TrainerProperties` mixin and moved property definitions directly into `Trainer` ([#9495](https://github.com/PyTorchLightning/pytorch-lightning/pull/9495))


### Fixed


Expand Down
2 changes: 1 addition & 1 deletion benchmarks/test_basic_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def vanilla_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):

def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
seed_everything(idx)
torch.backends.cudnn.deterministic = True

model = cls_model()
# init model parts
Expand All @@ -161,7 +162,6 @@ def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
weights_summary=None,
gpus=1 if device_type == "cuda" else 0,
checkpoint_callback=False,
deterministic=True,
logger=False,
replace_sampler_ddp=False,
)
Expand Down
33 changes: 22 additions & 11 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,6 @@
TorchElasticEnvironment,
)
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_HOROVOD_AVAILABLE,
_IPU_AVAILABLE,
_TPU_AVAILABLE,
AMPType,
device_parser,
DeviceType,
Expand All @@ -74,6 +70,14 @@
)
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import (
_APEX_AVAILABLE,
_HOROVOD_AVAILABLE,
_IPU_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
_TPU_AVAILABLE,
)

if _HOROVOD_AVAILABLE:
import horovod.torch as hvd
Expand All @@ -96,7 +100,7 @@ def __init__(
sync_batchnorm,
benchmark,
replace_sampler_ddp,
deterministic,
deterministic: bool,
precision,
amp_type,
amp_level,
Expand All @@ -113,6 +117,7 @@ def __init__(
f" Use `Trainer(accelerator={distributed_backend})` instead."
)
distributed_backend = distributed_backend or accelerator
self._init_deterministic(deterministic)

self.num_processes = num_processes
self.devices = devices
Expand All @@ -126,7 +131,6 @@ def __init__(
self.sync_batchnorm = sync_batchnorm
self.benchmark = benchmark
self.replace_sampler_ddp = replace_sampler_ddp
self.deterministic = deterministic
self.precision = precision
self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
self.amp_level = amp_level
Expand Down Expand Up @@ -177,15 +181,22 @@ def __init__(
# TODO: should this be moved to GPU accelerator?
torch.backends.cudnn.benchmark = self.benchmark

# determinism for cudnn
# TODO: should this be moved to GPU accelerator?
torch.backends.cudnn.deterministic = deterministic
self.replace_sampler_ddp = replace_sampler_ddp

def _init_deterministic(self, deterministic: bool) -> None:
self.deterministic = deterministic
if _TORCH_GREATER_EQUAL_1_8:
torch.use_deterministic_algorithms(deterministic)
elif _TORCH_GREATER_EQUAL_1_7:
torch.set_deterministic(deterministic)
else: # the minimum version Lightning supports is PyTorch 1.6
torch._set_deterministic(deterministic)
if deterministic:
# fixing non-deterministic part of horovod
# https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

self.replace_sampler_ddp = replace_sampler_ddp
# https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def select_accelerator_type(self) -> None:
if self.distributed_backend == "auto":
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def __init__(
Default: ``os.getcwd()``.
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

deterministic: If true enables cudnn.deterministic.
deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
Default: ``False``.

devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
based on the accelerator type.
Expand Down
17 changes: 5 additions & 12 deletions tests/accelerators/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.seed import seed_everything
from tests.accelerators.test_dp import CustomClassificationModelDP
from tests.helpers.boring_model import BoringModel
from tests.helpers.datamodules import ClassifDataModule
Expand All @@ -32,28 +33,20 @@
)
def test_evaluate(tmpdir, trainer_kwargs):
tutils.set_random_master_port()

seed_everything(1)
dm = ClassifDataModule()
model = CustomClassificationModelDP()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
limit_train_batches=10,
limit_val_batches=10,
deterministic=True,
**trainer_kwargs
default_root_dir=tmpdir, max_epochs=2, limit_train_batches=10, limit_val_batches=10, **trainer_kwargs
)

trainer.fit(model, datamodule=dm)
assert "ckpt" in trainer.checkpoint_callback.best_model_path

old_weights = model.layer_0.weight.clone().detach().cpu()

result = trainer.validate(datamodule=dm)
assert result[0]["val_acc"] > 0.55

result = trainer.test(datamodule=dm)
assert result[0]["test_acc"] > 0.55
trainer.validate(datamodule=dm)
trainer.test(datamodule=dm)

# make sure weights didn't change
new_weights = model.layer_0.weight.clone().detach().cpu()
Expand Down
2 changes: 1 addition & 1 deletion tests/checkpointing/test_legacy_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
callbacks=[es, stop],
max_epochs=21,
accumulate_grad_batches=2,
deterministic=True,
resume_from_checkpoint=path_ckpt,
)
torch.backends.cudnn.deterministic = True
trainer.fit(model, datamodule=dm)
res = trainer.test(model, datamodule=dm)
assert res[0]["test_loss"] <= 0.7
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torch.distributed

from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
from tests import _PATH_DATASETS


Expand Down Expand Up @@ -52,6 +53,7 @@ def restore_env_variables():
os.environ.update(env_backup)
# these are currently known leakers - ideally these would not be allowed
allowlist = {
"CUBLAS_WORKSPACE_CONFIG", # enabled with deterministic flag
"CUDA_DEVICE_ORDER",
"LOCAL_RANK",
"NODE_RANK",
Expand Down Expand Up @@ -87,6 +89,18 @@ def teardown_process_group():
torch.distributed.destroy_process_group()


@pytest.fixture(scope="function", autouse=True)
def reset_deterministic_algorithm():
"""Ensures that torch determinism settings are reset before the next test runs."""
yield
if _TORCH_GREATER_EQUAL_1_8:
torch.use_deterministic_algorithms(False)
elif _TORCH_GREATER_EQUAL_1_7:
torch.set_deterministic(False)
else: # the minimum version Lightning supports is PyTorch 1.6
torch._set_deterministic(False)


@pytest.fixture
def tmpdir_server(tmpdir):
if sys.version_info >= (3, 7):
Expand Down
10 changes: 0 additions & 10 deletions tests/models/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def test_horovod_cpu(tmpdir):
limit_train_batches=0.4,
limit_val_batches=0.2,
accelerator="horovod",
deterministic=True,
)
_run_horovod(trainer_options)

Expand All @@ -96,7 +95,6 @@ def test_horovod_cpu_clip_grad_by_value(tmpdir):
limit_train_batches=0.4,
limit_val_batches=0.2,
accelerator="horovod",
deterministic=True,
)
_run_horovod(trainer_options)

Expand All @@ -112,7 +110,6 @@ def test_horovod_cpu_implicit(tmpdir):
max_epochs=1,
limit_train_batches=0.4,
limit_val_batches=0.2,
deterministic=True,
)
_run_horovod(trainer_options)

Expand All @@ -129,7 +126,6 @@ def test_horovod_multi_gpu(tmpdir):
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
deterministic=True,
accelerator="horovod",
)
_run_horovod(trainer_options, on_gpu=True)
Expand All @@ -148,7 +144,6 @@ def test_horovod_multi_gpu_grad_by_value(tmpdir):
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
deterministic=True,
accelerator="horovod",
)
_run_horovod(trainer_options, on_gpu=True)
Expand All @@ -170,7 +165,6 @@ def test_horovod_apex(tmpdir):
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
deterministic=True,
accelerator="horovod",
amp_backend="apex",
precision=16,
Expand All @@ -190,7 +184,6 @@ def test_horovod_amp(tmpdir):
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
deterministic=True,
accelerator="horovod",
amp_backend="native",
precision=16,
Expand All @@ -210,7 +203,6 @@ def test_horovod_gather(tmpdir):
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
deterministic=True,
accelerator="horovod",
)
_run_horovod(trainer_options, on_gpu=True)
Expand All @@ -236,7 +228,6 @@ def validation_step(self, batch, *args, **kwargs):
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=1,
deterministic=True,
accelerator="horovod",
)
tpipes.run_model_test_without_loggers(trainer_options, model)
Expand All @@ -253,7 +244,6 @@ def test_horovod_multi_optimizer(tmpdir):
max_epochs=1,
limit_train_batches=0.4,
limit_val_batches=0.2,
deterministic=True,
accelerator="horovod",
)
trainer.fit(model)
Expand Down
13 changes: 9 additions & 4 deletions tests/overrides/test_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,11 @@ def training_step(self, batch, batch_idx):
return {"loss": loss}

model = TestModel()
model.trainer = Mock()
model.trainer.state.stage = RunningStage.TRAINING
trainer = MagicMock()
trainer.state.stage = RunningStage.TRAINING
trainer.accelerator_connector._init_deterministic(False)

model.trainer = trainer
batch = torch.rand(2, 32).cuda()
batch_idx = 0

Expand Down Expand Up @@ -123,8 +126,10 @@ def training_step(self, batch, batch_idx):
return output

model = TestModel().to(device)
model.trainer = Mock()
model.trainer.state.stage = RunningStage.TRAINING
trainer = MagicMock()
trainer.state.stage = RunningStage.TRAINING
trainer.accelerator_connector._init_deterministic(False)
model.trainer = trainer
batch = torch.rand(2, 32).to(device)
batch_idx = 0

Expand Down