From 0794f8cf1e6b4c7d9847a071f787742abe643f71 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 19 Feb 2021 18:00:27 +0100
Subject: [PATCH 1/4] Fix amp autocast (#6080)

* precision fixes

* add amp test model

* fix test

* revert

* move assert to training step

* fix test

* fix test

* remove unrelated changes

* add changelog

* remove unused import
---
 CHANGELOG.md                            |  7 ++++++
 .../plugins/precision/native_amp.py     |  3 ++-
 tests/models/test_amp.py                | 22 +++++++++++++------
 3 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b80afe7b24d0f..c1e6068f83fde 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [1.2.1] - 2021-02-23
+
+### Fixed
+
+- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 60c0f5f84626f..94e6cf376b03a 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -91,4 +91,5 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
     @contextmanager
     def train_step_context(self) -> Generator[autocast, None, None]:
         """Enable autocast context"""
-        yield torch.cuda.amp.autocast()
+        with torch.cuda.amp.autocast():
+            yield
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 2dd6c9d997dbf..53ec32764f3ed 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -27,6 +27,16 @@
 from tests.helpers import BoringModel
 
 
+class AMPTestModel(BoringModel):
+
+    def training_step(self, batch, batch_idx):
+        assert torch.is_autocast_enabled()
+        output = self(batch)
+        assert output.dtype == torch.float16
+        loss = self.loss(batch, output)
+        return {"loss": loss}
+
+
 @pytest.mark.skip(reason='dp + amp not supported currently')  # TODO
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_amp_single_gpu_dp(tmpdir):
@@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
 
@@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
-
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
 
@@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
-
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
     # simulate setting slurm flags
    tutils.set_random_master_port()
 
-    model = BoringModel()
+    model = AMPTestModel()
 
     # exp file to get meta
     logger = tutils.get_default_logger(tmpdir)
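The core of this patch is the two-line change in `train_step_context`: the old code yielded the `torch.cuda.amp.autocast()` object without ever entering it, so autocast never actually wrapped the training step. Below is a minimal sketch of that failure mode — it is not part of the patch, and it uses a stand-in context manager instead of the real `autocast` so it runs without a GPU.

```python
# Illustration only: why `yield ctx_manager()` differs from `with ctx_manager(): yield`.
from contextlib import contextmanager


class FakeAutocast:
    """Stand-in for torch.cuda.amp.autocast, used only for this sketch."""
    enabled = False

    def __enter__(self):
        FakeAutocast.enabled = True
        return self

    def __exit__(self, *exc):
        FakeAutocast.enabled = False


@contextmanager
def broken_train_step_context():
    # Old behaviour: the context-manager object is created and handed to the
    # caller, but __enter__ is never called, so nothing gets enabled.
    yield FakeAutocast()


@contextmanager
def fixed_train_step_context():
    # New behaviour: enter the context manager first, then yield, so the
    # caller's `with` block runs while it is active.
    with FakeAutocast():
        yield


with broken_train_step_context():
    assert not FakeAutocast.enabled  # the bug: never activated

with fixed_train_step_context():
    assert FakeAutocast.enabled      # the fix: active inside the block
```

This is also why the new `AMPTestModel` asserts `torch.is_autocast_enabled()` inside `training_step`: with the old yield logic that assertion would fail.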
From 2164989bd4c09a36463799ff3974467f8b9176bc Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Sat, 20 Feb 2021 12:58:54 +0000
Subject: [PATCH 2/4] [Hot Fix] Give priority to plugins to set distributed mode, and then accelerator (#6089)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Give priority to plugins to set distributed mode, and then accelerator

* Add CHANGELOG.md

* Update CHANGELOG.md

* Remove very scary line

* Ensure we set cluster environment after slurm configured if necessary

* Simplify the fix with a reset

Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                                |  3 +++
 .../connectors/accelerator_connector.py     |  4 +++-
 .../test_accelerator_connector.py           | 24 ++++++++++++++++++-
 3 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c1e6068f83fde..1916659068341 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
 
 
+- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index d32970d61fa9b..7021081d6cc90 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -163,6 +163,9 @@ def handle_given_plugins(
 
         for plug in plugins:
             if isinstance(plug, str):
+                # Reset the distributed type as the user has overridden training type
+                # via the plugins argument
+                self._distrib_type = None
                 self.set_distributed_mode(plug)
 
             elif isinstance(plug, TrainingTypePlugin):
@@ -196,7 +199,6 @@ def handle_given_plugins(
         )
 
         self._training_type_plugin = training_type
-        self._training_type_plugin = self.training_type_plugin
         self._precision_plugin = precision
         self._cluster_environment = cluster_environment or self.select_cluster_environment()
 
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 76d4a597d8ecb..82b631807c8e9 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -23,7 +23,14 @@
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPSpawnPlugin, PrecisionPlugin, SingleDevicePlugin
+from pytorch_lightning.plugins import (
+    DDP2Plugin,
+    DDPPlugin,
+    DDPShardedPlugin,
+    DDPSpawnPlugin,
+    PrecisionPlugin,
+    SingleDevicePlugin,
+)
 from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
 from tests.helpers.boring_model import BoringModel
 
@@ -378,3 +385,18 @@ def on_fit_start(self, trainer, pl_module):
 
     with pytest.raises(SystemExit):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize(
+    ["accelerator", "plugin"],
+    [('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
+)
+def test_plugin_accelerator_choice(accelerator, plugin):
+    """
+    Ensure that when a plugin and an accelerator are passed in, the plugin takes precedence.
+    """
+    trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2)
+    assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
+
+    trainer = Trainer(plugins=plugin, num_processes=2)
+    assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
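What the new test asserts, written out as a usage sketch: a training-type plugin passed through `plugins` now overrides the distributed mode implied by `accelerator` instead of being clobbered by it. This assumes pytorch-lightning with this fix (1.2.1) and fairscale installed, since the 'ddp_sharded' plugin requires it.

```python
# Usage sketch of the fixed precedence; not part of the patch itself.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPShardedPlugin

# 'ddp_spawn' alone would select DDPSpawnPlugin; the explicit plugin string wins.
trainer = Trainer(accelerator='ddp_spawn', plugins='ddp_sharded', num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
```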
From 0d0fe396ca52fff066af9cb01972b5e311035c4d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 22 Feb 2021 01:02:31 +0100
Subject: [PATCH 3/4] fix amp/apex misconfiguration error for cpu (#6107)

* fix weird test

* fix apex plugin test

* fix raise

* cpu test

* fix type

* add changelog
---
 CHANGELOG.md                          |  5 +-
 pytorch_lightning/accelerators/cpu.py |  2 +-
 tests/accelerators/test_cpu.py        | 21 +++++++++
 tests/plugins/test_amp_plugin.py      | 67 ++-------------------------
 tests/plugins/test_apex_plugin.py     | 40 ++++------------
 5 files changed, 38 insertions(+), 97 deletions(-)
 create mode 100644 tests/accelerators/test_cpu.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1916659068341..212de8d2ddf28 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,11 +15,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
 
 
+- Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added
 
-- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enum in metrics ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689)
+- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enum in metrics ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689))
 - Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))
 - Added support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959))
 - Added `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))
diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py
index 7c79c470001c3..83389b53f5bcb 100644
--- a/pytorch_lightning/accelerators/cpu.py
+++ b/pytorch_lightning/accelerators/cpu.py
@@ -7,7 +7,7 @@ class CPUAccelerator(Accelerator):
 
     def setup(self, trainer, model):
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
-            MisconfigurationException("amp + cpu is not supported. Please use a GPU option")
+            raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")
 
         if "cpu" not in str(self.root_device):
             raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead")
diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py
new file mode 100644
index 0000000000000..5c97527be048b
--- /dev/null
+++ b/tests/accelerators/test_cpu.py
@@ -0,0 +1,21 @@
+from unittest.mock import Mock
+
+import pytest
+import torch
+
+from pytorch_lightning.accelerators import CPUAccelerator
+from pytorch_lightning.plugins import SingleDevicePlugin
+from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+def test_unsupported_precision_plugins():
+    """ Test error messages are raised for unsupported precision plugins with CPU. """
+    trainer = Mock()
+    model = Mock()
+    accelerator = CPUAccelerator(
+        training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
+        precision_plugin=MixedPrecisionPlugin()
+    )
+    with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
+        accelerator.setup(trainer=trainer, model=model)
diff --git a/tests/plugins/test_amp_plugin.py b/tests/plugins/test_amp_plugin.py
index 80a06b0072e1e..8236a0990335a 100644
--- a/tests/plugins/test_amp_plugin.py
+++ b/tests/plugins/test_amp_plugin.py
@@ -5,10 +5,8 @@
 import torch
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
 from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 
 
@@ -25,78 +23,21 @@
 )
 @mock.patch('torch.cuda.device_count', return_value=2)
 @pytest.mark.parametrize(
-    ['ddp_backend', 'gpus', 'num_processes'],
-    [('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
+    ['ddp_backend', 'gpus'],
+    [('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)],
 )
-def on_fit_start(tmpdir, ddp_backend, gpus, num_processes):
-
-    class CB(Callback):
-
-        def on_fit_start(self, trainer, pl_module):
-            assert isinstance(trainer.precision_plugin, NativeMixedPrecisionPlugin)
-            raise SystemExit()
-
-    def train():
-        model = BoringModel()
-        trainer = Trainer(
-            fast_dev_run=True,
-            precision=16,
-            amp_backend='native',
-            gpus=gpus,
-            num_processes=num_processes,
-            accelerator=ddp_backend,
-            callbacks=[CB()],
-        )
-        trainer.fit(model)
-
-    if ddp_backend == "ddp_cpu":
-        with pytest.raises(MisconfigurationException, match="MP is only available on GPU"):
-            train()
-    else:
-        with pytest.raises(SystemExit):
-            train()
-
-
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
-@mock.patch.dict(
-    os.environ, {
-        "CUDA_VISIBLE_DEVICES": "0,1",
-        "SLURM_NTASKS": "2",
-        "SLURM_JOB_NAME": "SOME_NAME",
-        "SLURM_NODEID": "0",
-        "LOCAL_RANK": "0",
-        "SLURM_LOCALID": "0"
-    }
-)
-@mock.patch('torch.cuda.device_count', return_value=2)
-@pytest.mark.parametrize(
-    ['ddp_backend', 'gpus', 'num_processes'],
-    [('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
-)
-def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
+def test_amp_choice_custom_ddp_cpu(device_count_mock, ddp_backend, gpus):
 
     class MyNativeAMP(NativeMixedPrecisionPlugin):
         pass
 
-    class CB(Callback):
-
-        def on_fit_start(self, trainer, pl_module):
-            assert isinstance(trainer.precision_plugin, MyNativeAMP)
-            raise SystemExit()
-
-    model = BoringModel()
     trainer = Trainer(
-        fast_dev_run=True,
         precision=16,
         amp_backend='native',
-        num_processes=num_processes,
         accelerator=ddp_backend,
         plugins=[MyNativeAMP()],
-        callbacks=[CB()],
     )
-
-    with pytest.raises(SystemExit):
-        trainer.fit(model)
+    assert isinstance(trainer.precision_plugin, MyNativeAMP)
 
 
 class GradientUnscaleBoringModel(BoringModel):
diff --git a/tests/plugins/test_apex_plugin.py b/tests/plugins/test_apex_plugin.py
index 91d42822db57b..dd6c3f266928b 100644
--- a/tests/plugins/test_apex_plugin.py
+++ b/tests/plugins/test_apex_plugin.py
@@ -4,10 +4,8 @@
 import pytest
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin
 from pytorch_lightning.utilities import _APEX_AVAILABLE
-from tests.helpers.boring_model import BoringModel
 
 
 @pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@@ -23,30 +21,19 @@
 )
 @mock.patch('torch.cuda.device_count', return_value=2)
 @pytest.mark.parametrize(
-    ['ddp_backend', 'gpus', 'num_processes'],
-    [('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
+    ['ddp_backend', 'gpus'],
+    [('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)],
 )
-def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
+def test_amp_choice_default_ddp(mocked_device_count, ddp_backend, gpus):
 
-    class CB(Callback):
-
-        def on_fit_start(self, trainer, pl_module):
-            assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
-            raise SystemExit()
-
-    model = BoringModel()
     trainer = Trainer(
         fast_dev_run=True,
         precision=16,
         amp_backend='apex',
         gpus=gpus,
-        num_processes=num_processes,
         accelerator=ddp_backend,
-        callbacks=[CB()],
     )
-
-    with pytest.raises(SystemExit):
-        trainer.fit(model)
+    assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
 
 
 @pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@@ -62,31 +49,20 @@
 )
 @mock.patch('torch.cuda.device_count', return_value=2)
 @pytest.mark.parametrize(
-    ['ddp_backend', 'gpus', 'num_processes'],
-    [('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
+    ['ddp_backend', 'gpus'],
+    [('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)],
 )
-def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
+def test_amp_choice_custom_ddp(mocked_device_count, ddp_backend, gpus):
 
     class MyApexPlugin(ApexMixedPrecisionPlugin):
         pass
 
-    class CB(Callback):
-
-        def on_fit_start(self, trainer, pl_module):
-            assert isinstance(trainer.precision_plugin, MyApexPlugin)
-            raise SystemExit()
-
-    model = BoringModel()
     trainer = Trainer(
         fast_dev_run=True,
         precision=16,
         amp_backend='apex',
         gpus=gpus,
-        num_processes=num_processes,
         accelerator=ddp_backend,
         plugins=[MyApexPlugin(amp_level="O2")],
-        callbacks=[CB()],
     )
-
-    with pytest.raises(SystemExit):
-        trainer.fit(model)
+    assert isinstance(trainer.precision_plugin, MyApexPlugin)
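A note on the `raise` fix in `pytorch_lightning/accelerators/cpu.py` above: the original line merely constructed the exception object and discarded it, so the amp-on-CPU misconfiguration passed silently. Below is a tiny self-contained sketch of that failure mode; `ConfigError` is a local stand-in for Lightning's `MisconfigurationException`, used only so the example has no dependencies.

```python
# Illustration only -- constructing an exception without `raise` is a silent no-op.
class ConfigError(Exception):
    pass


def setup_before_fix():
    # Bug: the exception object is created and immediately discarded.
    ConfigError("amp + cpu is not supported. Please use a GPU option")


def setup_after_fix():
    # Fix: actually raise it so the user sees the misconfiguration.
    raise ConfigError("amp + cpu is not supported. Please use a GPU option")


setup_before_fix()  # returns silently, nothing is reported
try:
    setup_after_fix()
except ConfigError as err:
    print(f"raised as expected: {err}")
```

The new `tests/accelerators/test_cpu.py` pins this behaviour down with `pytest.raises(MisconfigurationException, ...)`.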
From 0ff0961f4c3e432e72db8e74194996e7b3ed386e Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 23 Feb 2021 18:30:36 +0100
Subject: [PATCH 4/4] Update version and remove CHANGELOG whitespace

---
 CHANGELOG.md                  | 4 ----
 pytorch_lightning/__init__.py | 2 +-
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 212de8d2ddf28..e6ea2a380927a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,11 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
 - Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
-
-
 - Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
-
-
 - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))
 
 
diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py
index b816a4e8aafb9..234b3220f55b6 100644
--- a/pytorch_lightning/__init__.py
+++ b/pytorch_lightning/__init__.py
@@ -5,7 +5,7 @@
 import time
 
 _this_year = time.strftime("%Y")
-__version__ = '1.2.0'
+__version__ = '1.2.1'
 __author__ = 'William Falcon et al.'
 __author_email__ = 'waf2107@columbia.edu'
 __license__ = 'Apache-2.0'