diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0e465056457c3..1154218c546b8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -57,8 +57,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))
 
-- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))
+- Added `gradient_clip_algorithm` argument to Trainer for gradient clipping by value ([#5477](https://github.com/PyTorchLightning/pytorch-lightning/pull/5477)).
+
+- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))
 
 - Added `ModelPruning` Callback ([#5618](https://github.com/PyTorchLightning/pytorch-lightning/pull/5618))
 
diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py
index 0dd11baee769f..c30ac9dc4ad08 100644
--- a/benchmarks/test_sharded_parity.py
+++ b/benchmarks/test_sharded_parity.py
@@ -164,6 +164,31 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
     )
 
 
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+@DDPLauncher.run("--accelerator ddp --gpus 2 --precision 16")
+def test_ddp_sharded_plugin_clip_gradients(tmpdir, args=None):
+    plugin_parity_test(
+        gpus=args.gpus,
+        precision=args.precision,
+        accelerator=args.accelerator,
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel,
+        gradient_clip_val=0.001,
+    )
+    plugin_parity_test(
+        gpus=args.gpus,
+        precision=args.precision,
+        accelerator=args.accelerator,
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel,
+        gradient_clip_val=0.001,
+        gradient_clip_algorithm='value',
+    )
+
+
 class SeedTrainLoaderModel(BoringModel):
     """
     Overrides training loader to ensure we enforce the same seed for all DDP processes.
@@ -266,6 +291,8 @@ def plugin_parity_test(
     gpus: int = 0,
     precision: int = 32,
     max_percent_speed_diff: float = 0.1,
+    gradient_clip_val: Union[int, float] = 0,
+    gradient_clip_algorithm: str = 'norm',
 ):
     """
     Ensures that the trained model is identical to the standard DDP implementation.
@@ -279,6 +306,8 @@ def plugin_parity_test(
         gpus: Number of GPUS to enable.
         precision: Whether to use AMP or normal FP32 training.
         max_percent_speed_diff: The maximum speed difference compared to normal DDP training.
             This is more a safety net for variability in CI which can vary in speed, not for benchmarking.
+        gradient_clip_val: 0 means don't clip.
+        gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Default: 'norm'.
""" @@ -309,6 +338,8 @@ def plugin_parity_test( precision=precision, accelerator=accelerator, plugins=[plugin], + gradient_clip_val=gradient_clip_val, + gradient_clip_algorithm=gradient_clip_algorithm, ) max_memory_custom, custom_model_time = record_ddp_fit_model_stats( diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst index d7230a1fd687a..0bf740afe6325 100644 --- a/docs/source/advanced/training_tricks.rst +++ b/docs/source/advanced/training_tricks.rst @@ -26,8 +26,10 @@ The effect is a large effective batch size of size KxN. Gradient Clipping ----------------- -Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will `clip the gradient -norm `_ computed over all model parameters together. +Gradient clipping may be enabled to avoid exploding gradients. By default, this will `clip the gradient norm +`_ computed over all model parameters together. +If gradient_clip_algorithm option is set to 'value', which is 'norm' by default, this will +`clip the gradient value `_ for each parameter instead. .. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer` @@ -39,6 +41,9 @@ norm `_ # clip gradients with norm above 0.5 trainer = Trainer(gradient_clip_val=0.5) + # clip gradients with value above 0.5 + trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value') + ---------- Auto scaling of batch size diff --git a/pytorch_lightning/accelerators/legacy/accelerator.py b/pytorch_lightning/accelerators/legacy/accelerator.py index 0788b26f845be..be93f79d47d44 100644 --- a/pytorch_lightning/accelerators/legacy/accelerator.py +++ b/pytorch_lightning/accelerators/legacy/accelerator.py @@ -21,6 +21,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.legacy.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.legacy.rpc_plugin import RPCPlugin +from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.parsing import AttributeDict @@ -117,12 +118,16 @@ def clip_gradients(self, optimizer, clip_val=None): return self._clip_gradients(optimizer, grad_clip_val) - def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0): + def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: float, norm_type: float = 2.0): + clip_algorithm = self.trainer.gradient_clip_algorithm if self.trainer.amp_backend: - self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer, norm_type) + self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, clip_algorithm, optimizer, norm_type) else: model = self.trainer.get_model() - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type) + if clip_algorithm == GradClipAlgorithmType.VALUE: + torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=grad_clip_val) + elif clip_algorithm == GradClipAlgorithmType.NORM: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type) def on_train_epoch_end(self, outputs): pass diff --git a/pytorch_lightning/accelerators/legacy/tpu_accelerator.py b/pytorch_lightning/accelerators/legacy/tpu_accelerator.py index 88b73fe94939f..bccb5e61aea5f 100644 --- a/pytorch_lightning/accelerators/legacy/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/legacy/tpu_accelerator.py @@ -26,6 +26,7 @@ from pytorch_lightning.core import LightningModule 
 from pytorch_lightning.utilities import (
     _TPU_AVAILABLE,
+    GradClipAlgorithmType,
     move_data_to_device,
     rank_zero_info,
     rank_zero_only,
@@ -245,27 +246,33 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
 
         return closure_loss
 
-    def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
-        # this code is a modification of torch.nn.utils.clip_grad_norm_
+    def _clip_gradients(self,
+                        optimizer: Optimizer,
+                        grad_clip_val: float,
+                        gradient_clip_algorithm: str,
+                        norm_type: float):
+        # this code contains a modification of torch.nn.utils.clip_grad_norm_
         # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
         model = self.trainer.get_model()
         parameters = model.parameters()
-        max_norm = grad_clip_val
-
-        if isinstance(parameters, torch.Tensor):
-            parameters = [parameters]
-        parameters = list(filter(lambda p: p.grad is not None, parameters))
-
-        device = parameters[0].device
-        out = torch.empty(len(parameters), device=device)
-        for i, p in enumerate(parameters):
-            torch.norm(p.grad.data.to(device), norm_type, out=out[i])
-        total_norm = torch.norm(out, norm_type)
-
-        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon)
-        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
-        for p in parameters:
-            p.grad.data.mul_(clip_coef.to(p.grad.data.device))
+        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
+            torch.nn.utils.clip_grad_value_(parameters, clip_value=grad_clip_val)
+        elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
+            max_norm = grad_clip_val
+            if isinstance(parameters, torch.Tensor):
+                parameters = [parameters]
+            parameters = list(filter(lambda p: p.grad is not None, parameters))
+
+            device = parameters[0].device
+            out = torch.empty(len(parameters), device=device)
+            for i, p in enumerate(parameters):
+                torch.norm(p.grad.data.to(device), norm_type, out=out[i])
+            total_norm = torch.norm(out, norm_type)
+
+            clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon)
+            clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
+            for p in parameters:
+                p.grad.data.mul_(clip_coef.to(p.grad.data.device))
 
     def barrier(self, name: Optional[str] = None):
         torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}")
diff --git a/pytorch_lightning/plugins/legacy/apex.py b/pytorch_lightning/plugins/legacy/apex.py
index 49a9c57fd5927..179c8b8dfa2b7 100644
--- a/pytorch_lightning/plugins/legacy/apex.py
+++ b/pytorch_lightning/plugins/legacy/apex.py
@@ -11,14 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Tuple, Union
+from typing import List, Tuple
 
 import torch
 from torch.optim.optimizer import Optimizer
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.legacy.precision_plugin import PrecisionPlugin
-from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, GradClipAlgorithmType
 from pytorch_lightning.utilities.distributed import rank_zero_warn
 
 if _APEX_AVAILABLE:
@@ -98,34 +98,42 @@ def configure_apex(self, amp, model, optimizers, amp_level):
         model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
         return model, optimizers
 
-    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
+    def clip_gradients(self,
+                       optimizer: Optimizer,
+                       grad_clip_val: float,
+                       gradient_clip_algorithm: str,
+                       norm_type: float):
         """
-        This code is a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights.
+        This code contains a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights.
         This is important when setting amp_level to O2, and the master weights are in fp16.
 
         Args:
-            grad_clip_val: Maximum norm of gradients.
             optimizer: Optimizer with gradients that will be clipped.
-            norm_type: (float or int): type of the used p-norm. Can be ``'inf'`` for
-                infinity norm.
+            grad_clip_val: Maximum norm of gradients.
+            gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm.
+            norm_type: type of the used p-norm. Can be ``'inf'`` for infinity norm.
         """
         model = self.trainer.get_model()
         parameters = model.parameters()
-        max_norm = float(grad_clip_val)
-
-        if isinstance(parameters, torch.Tensor):
-            parameters = [parameters]
-        parameters = [p for p in parameters if p.grad is not None]
-
-        if len(parameters) == 0:
-            return torch.tensor(0.)
-        device = parameters[0].grad.device
-        total_norm = torch.norm(
-            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
-        clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon)
-        if clip_coef < 1:
-            for p in parameters:
-                p.grad.detach().mul_(clip_coef.to(p.grad.device))
+
+        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
+            torch.nn.utils.clip_grad_value_(parameters, clip_value=grad_clip_val)
+        elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
+            max_norm = float(grad_clip_val)
+
+            if isinstance(parameters, torch.Tensor):
+                parameters = [parameters]
+            parameters = [p for p in parameters if p.grad is not None]
+
+            if len(parameters) == 0:
+                return torch.tensor(0.)
+            device = parameters[0].grad.device
+            total_norm = torch.norm(
+                torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+            clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon)
+            if clip_coef < 1:
+                for p in parameters:
+                    p.grad.detach().mul_(clip_coef.to(p.grad.device))
 
     @property
     def norm_clipping_epsilon(self):
diff --git a/pytorch_lightning/plugins/legacy/native_amp.py b/pytorch_lightning/plugins/legacy/native_amp.py
index 0a38a90acb79f..941042d9bc4ad 100644
--- a/pytorch_lightning/plugins/legacy/native_amp.py
+++ b/pytorch_lightning/plugins/legacy/native_amp.py
@@ -11,13 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
 
 import torch
 from torch.optim import Optimizer
 
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.legacy.precision_plugin import PrecisionPlugin
+from pytorch_lightning.utilities import GradClipAlgorithmType
 
 
 class NativeAMPPlugin(PrecisionPlugin):
@@ -60,9 +60,16 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
 
         return closure_loss
 
-    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
+    def clip_gradients(self,
+                       optimizer: Optimizer,
+                       grad_clip_val: float,
+                       gradient_clip_algorithm: str,
+                       norm_type: float):
         model = self.trainer.get_model()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)
+        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
+            torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=grad_clip_val)
+        elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)
 
     @property
     def scaler(self):
diff --git a/pytorch_lightning/plugins/legacy/precision_plugin.py b/pytorch_lightning/plugins/legacy/precision_plugin.py
index 1041e9d6b0faf..e7c918c1df809 100644
--- a/pytorch_lightning/plugins/legacy/precision_plugin.py
+++ b/pytorch_lightning/plugins/legacy/precision_plugin.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
 
 from torch.optim import Optimizer
 
@@ -35,5 +34,9 @@ def training_step(self, fx, args):
     def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         raise NotImplementedError
 
-    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
+    def clip_gradients(self,
+                       optimizer: Optimizer,
+                       grad_clip_val: float,
+                       gradient_clip_algorithm: str,
+                       norm_type: float):
         raise NotImplementedError
diff --git a/pytorch_lightning/plugins/legacy/sharded_native_amp_plugin.py b/pytorch_lightning/plugins/legacy/sharded_native_amp_plugin.py
index b2523ef3fce0a..4c3586e42c3ed 100644
--- a/pytorch_lightning/plugins/legacy/sharded_native_amp_plugin.py
+++ b/pytorch_lightning/plugins/legacy/sharded_native_amp_plugin.py
@@ -11,12 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from typing import cast, Union
 
+import torch
 from torch.optim import Optimizer
 
 from pytorch_lightning.plugins.legacy.native_amp import NativeAMPPlugin
-from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
+from pytorch_lightning.utilities import (
+    _FAIRSCALE_AVAILABLE,
+    _NATIVE_AMP_AVAILABLE,
+    GradClipAlgorithmType,
+)
 
 if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
@@ -28,8 +34,15 @@ class ShardedNativeAMPPlugin(NativeAMPPlugin):
     def scaler(self):
         return ShardedGradScaler()
 
-    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
-        max_norm = grad_clip_val
-        norm_type = float(2.0)
-        optimizer = cast(OSS, optimizer)
-        optimizer.clip_grad_norm(max_norm, norm_type=norm_type)
+    def clip_gradients(self,
+                       optimizer: Optimizer,
+                       grad_clip_val: float,
+                       gradient_clip_algorithm: str,
+                       norm_type: float):
+
+        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
+            model = self.trainer.get_model()
+            torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=grad_clip_val)
+        elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
+            optimizer = cast(OSS, optimizer)
+            optimizer.clip_grad_norm(grad_clip_val, norm_type=norm_type)
diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index b5d1d45461cff..a33c929d8fe49 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
+from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -23,6 +26,7 @@ def __init__(self, trainer):
     def on_trainer_init(
         self,
         gradient_clip_val,
+        gradient_clip_algorithm,
         track_grad_norm,
         accumulate_grad_batches,
         truncated_bptt_steps,
@@ -32,7 +36,11 @@
         self.trainer.terminate_on_nan = terminate_on_nan
 
         # gradient clipping
+        if gradient_clip_algorithm not in [GradClipAlgorithmType.VALUE, GradClipAlgorithmType.NORM]:
+            raise MisconfigurationException(f"gradient_clip_algorithm should be "
+                                            f"'{GradClipAlgorithmType.VALUE}' or '{GradClipAlgorithmType.NORM}'")
         self.trainer.gradient_clip_val = gradient_clip_val
+        self.trainer.gradient_clip_algorithm = gradient_clip_algorithm
 
         # gradient norm tracking
         if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ba34c49581038..51be66a5d2f30 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -87,7 +87,8 @@
         checkpoint_callback: bool = True,
         callbacks: Optional[Union[List[Callback], Callback]] = None,
         default_root_dir: Optional[str] = None,
-        gradient_clip_val: float = 0,
+        gradient_clip_val: Union[int, float] = 0,
+        gradient_clip_algorithm: str = 'norm',
         process_position: int = 0,
         num_nodes: int = 1,
         num_processes: int = 1,
@@ -197,6 +198,8 @@
 
             gradient_clip_val: 0 means don't clip.
 
+            gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Default: 'norm'.
+
             limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)
 
             limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)
 
@@ -349,7 +352,12 @@
         # init training tricks
         self.training_tricks_connector.on_trainer_init(
-            gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan
+            gradient_clip_val,
+            gradient_clip_algorithm,
+            track_grad_norm,
+            accumulate_grad_batches,
+            truncated_bptt_steps,
+            terminate_on_nan,
         )
 
         # init accelerator related flags
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index a8f3e134936ff..fcb1db43d6933 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -22,7 +22,13 @@
     rank_zero_only,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.enums import AMPType, DeviceType, DistributedType, LightningEnum  # noqa: F401
+from pytorch_lightning.utilities.enums import (  # noqa: F401
+    AMPType,
+    DeviceType,
+    DistributedType,
+    GradClipAlgorithmType,
+    LightningEnum,
+)
 from pytorch_lightning.utilities.imports import (  # noqa: F401
     _APEX_AVAILABLE,
     _BOLTS_AVAILABLE,
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index f6c0bf1d6cc54..75ce4ab069650 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -80,3 +80,13 @@ class DeviceType(LightningEnum):
     CPU = 'CPU'
     GPU = 'GPU'
     TPU = 'TPU'
+
+
+class GradClipAlgorithmType(LightningEnum):
+    """ Define gradient_clip_algorithm types - training-tricks.
+
+    >>> GradClipAlgorithmType.VALUE in ('value', 'norm')
+    True
+    """
+    VALUE = 'value'
+    NORM = 'norm'
diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 85e91c4ae9d84..540123e3489a8 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -17,6 +17,7 @@
 import shlex
 import subprocess
 import sys
+from copy import deepcopy
 
 import numpy as np
 import pytest
@@ -84,6 +85,11 @@ def test_horovod_cpu(tmpdir):
     )
     _run_horovod(trainer_options)
 
+    # clip_grad_by_value test
+    trainer_options_clip_grad_val = deepcopy(trainer_options)
+    trainer_options_clip_grad_val.update({'gradient_clip_algorithm': 'value'})
+    _run_horovod(trainer_options_clip_grad_val)
+
 
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
 def test_horovod_cpu_implicit(tmpdir):
@@ -100,6 +106,11 @@ def test_horovod_cpu_implicit(tmpdir):
     )
     _run_horovod(trainer_options)
 
+    # clip_grad_by_value test
+    trainer_options_clip_grad_val = deepcopy(trainer_options)
+    trainer_options_clip_grad_val.update({'gradient_clip_algorithm': 'value'})
+    _run_horovod(trainer_options_clip_grad_val)
+
 
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
 @pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@@ -120,6 +131,11 @@ def test_horovod_multi_gpu(tmpdir):
     )
     _run_horovod(trainer_options, on_gpu=True)
 
+    # clip_grad_by_value test
+    trainer_options_clip_grad_val = deepcopy(trainer_options)
+    trainer_options_clip_grad_val.update({'gradient_clip_algorithm': 'value'})
+    _run_horovod(trainer_options_clip_grad_val, on_gpu=True)
+
 
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
 @pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@@ -143,6 +159,11 @@ def test_horovod_apex(tmpdir):
     )
     _run_horovod(trainer_options, on_gpu=True)
 
+    # clip_grad_by_value test
+    trainer_options_clip_grad_val = deepcopy(trainer_options)
+    trainer_options_clip_grad_val.update({'gradient_clip_algorithm': 'value'})
+    _run_horovod(trainer_options_clip_grad_val, on_gpu=True)
+
 
 @pytest.mark.skip(reason="Skip till Horovod fixes integration with Native torch.cuda.amp")
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@@ -167,6 +188,11 @@ def test_horovod_amp(tmpdir):
     )
     _run_horovod(trainer_options, on_gpu=True)
 
+    # clip_grad_by_value test
+    trainer_options_clip_grad_val = deepcopy(trainer_options)
+    trainer_options_clip_grad_val.update({'gradient_clip_algorithm': 'value'})
+    _run_horovod(trainer_options_clip_grad_val, on_gpu=True)
+
 
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
 @pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 5e977eed765d0..61fcadf7b28da 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -197,6 +197,25 @@ def test_tpu_grad_norm(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
 
 
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+def test_tpu_clip_grad_by_value(tmpdir):
+    """Test if clip_gradients by value works on TPU."""
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        tpu_cores=1,
+        limit_train_batches=0.4,
+        limit_val_batches=0.4,
+        gradient_clip_val=0.1,
+        gradient_clip_algorithm='value'
+    )
+
+    model = EvalModelTemplate()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+
+
 @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
 def test_dataloaders_passed_to_fit(tmpdir):
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index e21351704fd4c..7376580143331 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -966,6 +966,46 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde
     trainer.fit(model)
 
 
+def test_gradient_clipping_by_value(tmpdir):
+    """
+    Test gradient clipping by value
+    """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+
+    grad_clip_val = 0.0001
+    trainer = Trainer(
+        max_steps=10,
+        max_epochs=1,
+        gradient_clip_val=grad_clip_val,
+        gradient_clip_algorithm='value',
+        default_root_dir=tmpdir,
+    )
+
+    trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward
+
+    def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens):
+        """
+        wrap the forward step in a closure so second order methods work
+        """
+        # test that gradient is clipped correctly
+        ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
+        parameters = model.parameters()
+        grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
+        grad_max = torch.max(torch.stack(grad_max_list))
+        assert round(grad_max.item(), 6) <= grad_clip_val, \
+            f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val}."
+
+        return ret_val
+
+    trainer.train_loop.training_step_and_backward = training_step_and_backward
+    # for the test
+    model.prev_called_batch_idx = 0
+
+    trainer.fit(model)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
 def test_gradient_clipping_fp16(tmpdir):
@@ -1005,6 +1045,47 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde
     trainer.fit(model)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
+def test_gradient_clipping_by_value_fp16(tmpdir):
+    """
+    Test gradient clipping by value with fp16
+    """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    grad_clip_val = 0.0001
+    trainer = Trainer(
+        max_steps=10,
+        max_epochs=1,
+        precision=16,
+        gpus=1,
+        gradient_clip_val=grad_clip_val,
+        gradient_clip_algorithm='value',
+        default_root_dir=tmpdir,
+    )
+
+    trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward
+
+    def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens):
+        """
+        wrap the forward step in a closure so second order methods work
+        """
+        # test that gradient is clipped correctly
+        ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
+        parameters = model.parameters()
+        grad_max = torch.max(torch.stack([p.grad.detach().abs().max() for p in parameters]))
+        assert round(grad_max.item(), 6) <= grad_clip_val, \
+            f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val}."
+
+        return ret_val
+
+    trainer.train_loop.training_step_and_backward = training_step_and_backward
+    model.prev_called_batch_idx = 0
+
+    trainer.fit(model)
+
+
 def test_gpu_choice(tmpdir):
     trainer_options = dict(
         default_root_dir=tmpdir,
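For reference, a minimal usage sketch of the option this patch introduces. It is not part of the diff above and assumes only the public API visible in the changes (the `gradient_clip_val` / `gradient_clip_algorithm` Trainer arguments and the torch.nn.utils clipping functions the plugins call).

    # Hedged usage sketch, not part of the patch.
    from pytorch_lightning import Trainer

    # Default behaviour ('norm'): rescale gradients so their total norm is at most 0.5.
    trainer = Trainer(gradient_clip_val=0.5)

    # New behaviour ('value'): clamp each gradient element to the range [-0.5, 0.5].
    trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')

    # The two settings map onto the underlying torch utilities used in the plugins above:
    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)      # 'norm'
    # torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)   # 'value'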