diff --git a/CHANGELOG.md b/CHANGELOG.md index cc78de0f9c0c1..b6c58543aa271 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,7 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), - + [#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547), [#6515](https://github.com/PyTorchLightning/pytorch-lightning/pull/6515), @@ -113,6 +113,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565)) + - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index dc822680bcbda..3c83945c8a1b7 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -103,3 +103,21 @@ def train_step_context(self) -> Generator[None, None, None]: """Enable autocast context""" with torch.cuda.amp.autocast(): yield + + @contextmanager + def val_step_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield + + @contextmanager + def test_step_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield + + @contextmanager + def predict_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 9853db342436b..0b9d6776c1aaa 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -17,24 +17,43 @@ import pytest import torch from torch import optim +from torch.utils.data import DataLoader import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf class AMPTestModel(BoringModel): - def training_step(self, batch, batch_idx): + def _step(self, batch, batch_idx): assert torch.is_autocast_enabled() output = self(batch) assert output.dtype == torch.float16 loss = self.loss(batch, output) - return {"loss": loss} + return loss + + def training_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"loss": output} + + def validation_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"x": output} + + def test_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"y": output} + + def predict(self, batch, batch_idx, dataloader_idx=None): + assert torch.is_autocast_enabled() + output = self(batch) + assert output.dtype == torch.float16 + return output @pytest.mark.skip(reason='dp + amp not supported currently') # TODO @@ -54,6 +73,8 @@ def test_amp_single_gpu_dp(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @@ -73,6 +94,8 @@ def test_amp_single_gpu_ddp_spawn(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @@ -112,6 +135,8 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"