25 commits
4111955
add clip_grad_by_value feature
dhkim0225 Jan 12, 2021
c09e990
write changelog, training_tricks.rst
dhkim0225 Jan 12, 2021
baa9a49
add end line to sharded_natvie_amp_pluigin.py
dhkim0225 Jan 12, 2021
a236611
bugfix update regex
dhkim0225 Jan 12, 2021
749e286
update regex documentation
dhkim0225 Jan 12, 2021
b83ea7f
revert changelog
dhkim0225 Jan 18, 2021
8a18422
Merge pull request #1 from PyTorchLightning/release/1.2-dev
dhkim0225 Jan 18, 2021
9913351
commit based on review
dhkim0225 Jan 18, 2021
af9434e
Merge branch 'release/1.2-dev' into feature/clip_grad_by_value_1.2-dev
dhkim0225 Jan 18, 2021
65e694f
Merge branch 'release/1.2-dev' into feature/clip_grad_by_value_1.2-dev
tchaton Jan 19, 2021
4c8e46b
Add Enum Type
dhkim0225 Jan 20, 2021
e2484ae
edit CHANGELOG.md to prevent conflicts
dhkim0225 Jan 20, 2021
393b77b
add test codes
dhkim0225 Jan 20, 2021
2403f96
pep8 formatting
dhkim0225 Jan 20, 2021
79d4149
update test codes
dhkim0225 Jan 20, 2021
9fc6f62
update test codes
dhkim0225 Jan 20, 2021
853fb09
update test codes
dhkim0225 Jan 20, 2021
d04db18
Merge branch 'release/1.2-dev' into feature/clip_grad_by_value_1.2-dev
dhkim0225 Jan 22, 2021
257bcb6
add value clipping for sharded ddp
dhkim0225 Jan 22, 2021
06aa1d3
Merge branch 'release/1.2-dev' into feature/clip_grad_by_value_1.2-dev
dhkim0225 Jan 22, 2021
b32048f
Merge branch 'release/1.2-dev' into feature/clip_grad_by_value_1.2-dev
dhkim0225 Jan 26, 2021
6d3bf7c
Merge branch 'release/1.2-dev' into feature/clip_grad_by_value_1.2-dev
dhkim0225 Jan 26, 2021
cc6d794
Merge branch 'release/1.2-dev' of https://github.com/PyTorchLightning…
dhkim0225 Jan 27, 2021
7737ae5
Merge branch 'release/1.2-dev' into feature/clip_grad_by_value_1.2-dev
dhkim0225 Jan 28, 2021
dcf9ff0
remove bad line in native_amp.py
dhkim0225 Jan 29, 2021
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -57,8 +57,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))


- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))
- Added `gradient_clip_algorithm` argument to Trainer for gradient clipping by value ([#5477](https://github.com/PyTorchLightning/pytorch-lightning/pull/5477)).


- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))

- Added `ModelPruning` Callback ([#5618](https://github.com/PyTorchLightning/pytorch-lightning/pull/5618))

31 changes: 31 additions & 0 deletions benchmarks/test_sharded_parity.py
@@ -164,6 +164,31 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
)


@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
@DDPLauncher.run("--accelerator ddp --gpus 2 --precision 16")
def test_ddp_sharded_plugin_clip_gradients(tmpdir, args=None):
    plugin_parity_test(
        gpus=args.gpus,
        precision=args.precision,
        accelerator=args.accelerator,
        plugin=DDPShardedPlugin(),
        model_cls=SeedTrainLoaderModel,
        gradient_clip_val=0.001,
    )
    plugin_parity_test(
        gpus=args.gpus,
        precision=args.precision,
        accelerator=args.accelerator,
        plugin=DDPShardedPlugin(),
        model_cls=SeedTrainLoaderModel,
        gradient_clip_val=0.001,
        gradient_clip_algorithm='value',
    )


class SeedTrainLoaderModel(BoringModel):
"""
Overrides training loader to ensure we enforce the same seed for all DDP processes.
@@ -266,6 +291,8 @@ def plugin_parity_test(
gpus: int = 0,
precision: int = 32,
max_percent_speed_diff: float = 0.1,
gradient_clip_val: Union[int, float] = 0,
gradient_clip_algorithm: str = 'norm',
):
"""
Ensures that the trained model is identical to the standard DDP implementation.
@@ -279,6 +306,8 @@
gpus: Number of GPUS to enable.
precision: Whether to use AMP or normal FP32 training.
max_percent_speed_diff: The maximum speed difference compared to normal DDP training.
gradient_clip_val: 0 means don't clip.
gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Default: 'norm'.
This is more a safety net for variability in CI which can vary in speed, not for benchmarking.

"""
@@ -309,6 +338,8 @@
precision=precision,
accelerator=accelerator,
plugins=[plugin],
gradient_clip_val=gradient_clip_val,
gradient_clip_algorithm=gradient_clip_algorithm,
)

max_memory_custom, custom_model_time = record_ddp_fit_model_stats(
9 changes: 7 additions & 2 deletions docs/source/advanced/training_tricks.rst
@@ -26,8 +26,10 @@ The effect is a large effective batch size of size KxN.

Gradient Clipping
-----------------
Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will `clip the gradient
norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
Gradient clipping may be enabled to avoid exploding gradients. By default, this will `clip the gradient norm
<https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
If the ``gradient_clip_algorithm`` option (``'norm'`` by default) is set to ``'value'``, this will
`clip the gradient value <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value_>`_ for each parameter instead.

.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`

@@ -39,6 +41,9 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_
# clip gradients with norm above 0.5
trainer = Trainer(gradient_clip_val=0.5)

# clip gradients with value above 0.5
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')
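
For intuition, the two algorithms map onto the underlying PyTorch utilities. A minimal standalone sketch, assuming ``model`` is an ``nn.Module`` whose gradients have already been computed:

import torch

# 'norm': rescale all gradients together so that their combined norm is at most 0.5
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

# 'value': clamp every gradient element into the range [-0.5, 0.5]
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)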

----------

Auto scaling of batch size
11 changes: 8 additions & 3 deletions pytorch_lightning/accelerators/legacy/accelerator.py
@@ -21,6 +21,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.legacy.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.legacy.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.parsing import AttributeDict

@@ -117,12 +118,16 @@ def clip_gradients(self, optimizer, clip_val=None):
return
self._clip_gradients(optimizer, grad_clip_val)

def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: float, norm_type: float = 2.0):
clip_algorithm = self.trainer.gradient_clip_algorithm
if self.trainer.amp_backend:
self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer, norm_type)
self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, clip_algorithm, optimizer, norm_type)
else:
model = self.trainer.get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)
if clip_algorithm == GradClipAlgorithmType.VALUE:
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=grad_clip_val)
elif clip_algorithm == GradClipAlgorithmType.NORM:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)
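
# The comparisons above use GradClipAlgorithmType, imported from
# pytorch_lightning.utilities but not shown in this diff. A string-backed enum
# along the following lines would let both the enum members and the plain
# 'value' / 'norm' strings passed through the Trainer compare equal
# (a sketch under that assumption, not the actual source):
from enum import Enum

class GradClipAlgorithmType(str, Enum):
    VALUE = 'value'
    NORM = 'norm'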

def on_train_epoch_end(self, outputs):
pass
43 changes: 25 additions & 18 deletions pytorch_lightning/accelerators/legacy/tpu_accelerator.py
@@ -26,6 +26,7 @@
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import (
_TPU_AVAILABLE,
GradClipAlgorithmType,
move_data_to_device,
rank_zero_info,
rank_zero_only,
@@ -245,27 +246,33 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):

return closure_loss

def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
# this code is a modification of torch.nn.utils.clip_grad_norm_
def _clip_gradients(self,
optimizer: Optimizer,
grad_clip_val: float,
gradient_clip_algorithm: str,
norm_type: float):
# this code contains a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
model = self.trainer.get_model()
parameters = model.parameters()
max_norm = grad_clip_val

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))

device = parameters[0].device
out = torch.empty(len(parameters), device=device)
for i, p in enumerate(parameters):
torch.norm(p.grad.data.to(device), norm_type, out=out[i])
total_norm = torch.norm(out, norm_type)

clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon)
clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
for p in parameters:
p.grad.data.mul_(clip_coef.to(p.grad.data.device))
if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
torch.nn.utils.clip_grad_value_(parameters, clip_value=grad_clip_val)
elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
max_norm = grad_clip_val
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))

device = parameters[0].device
out = torch.empty(len(parameters), device=device)
for i, p in enumerate(parameters):
torch.norm(p.grad.data.to(device), norm_type, out=out[i])
total_norm = torch.norm(out, norm_type)

clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon)
clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
for p in parameters:
p.grad.data.mul_(clip_coef.to(p.grad.data.device))

def barrier(self, name: Optional[str] = None):
torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}")
52 changes: 30 additions & 22 deletions pytorch_lightning/plugins/legacy/apex.py
@@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Union
from typing import List, Tuple

import torch
from torch.optim.optimizer import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.legacy.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, GradClipAlgorithmType
from pytorch_lightning.utilities.distributed import rank_zero_warn

if _APEX_AVAILABLE:
@@ -98,34 +98,42 @@ def configure_apex(self, amp, model, optimizers, amp_level):
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
return model, optimizers

def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
def clip_gradients(self,
optimizer: Optimizer,
grad_clip_val: float,
gradient_clip_algorithm: str,
norm_type: float):
"""
This code is a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights.
This code contains a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights.
This is important when setting amp_level to O2, and the master weights are in fp16.

Args:
grad_clip_val: Maximum norm of gradients.
optimizer: Optimizer with gradients that will be clipped.
norm_type: (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
grad_clip_val: Maximum norm of gradients.
gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm.
norm_type: type of the used p-norm. Can be ``'inf'`` for infinity norm.
"""
model = self.trainer.get_model()
parameters = model.parameters()
max_norm = float(grad_clip_val)

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]

if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
total_norm = torch.norm(
torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef.to(p.grad.device))

if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
torch.nn.utils.clip_grad_value_(parameters, clip_value=grad_clip_val)
if gradient_clip_algorithm == GradClipAlgorithmType.NORM:
max_norm = float(grad_clip_val)

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]

if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
total_norm = torch.norm(
torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef.to(p.grad.device))
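
# Worked toy example of the norm branch above (illustration only): with
# grad_clip_val = 1.0, a current total_norm of 4.0 and a negligible epsilon,
#     clip_coef = 1.0 / (4.0 + eps) ~= 0.25
# so every gradient is multiplied by ~0.25 and the combined norm drops back to ~1.0.
# When total_norm is already below the threshold, clip_coef >= 1 and the
# gradients are left untouched.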

@property
def norm_clipping_epsilon(self):
13 changes: 10 additions & 3 deletions pytorch_lightning/plugins/legacy/native_amp.py
@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.legacy.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType


class NativeAMPPlugin(PrecisionPlugin):
@@ -60,9 +60,16 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):

return closure_loss

def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
def clip_gradients(self,
optimizer: Optimizer,
grad_clip_val: float,
gradient_clip_algorithm: str,
norm_type: float):
model = self.trainer.get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)
if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=grad_clip_val)
elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)
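
# For context, a standalone sketch of the usual native-AMP clipping recipe this
# plugin maps onto (plain torch.cuda.amp, outside Lightning; `scaler`, `model`,
# `optimizer` and `loss` are assumed to already exist). Clipping has to act on
# unscaled gradients:
#
#     scaler.scale(loss).backward()
#     scaler.unscale_(optimizer)   # bring gradients back to their true scale
#     torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
#     scaler.step(optimizer)
#     scaler.update()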

@property
def scaler(self):
7 changes: 5 additions & 2 deletions pytorch_lightning/plugins/legacy/precision_plugin.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

from torch.optim import Optimizer

@@ -35,5 +34,9 @@ def training_step(self, fx, args):
def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
raise NotImplementedError

def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
def clip_gradients(self,
optimizer: Optimizer,
grad_clip_val: float,
gradient_clip_algorithm: str,
norm_type: float):
raise NotImplementedError
25 changes: 19 additions & 6 deletions pytorch_lightning/plugins/legacy/sharded_native_amp_plugin.py
@@ -11,12 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.plugins.legacy.native_amp import NativeAMPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import (
_FAIRSCALE_AVAILABLE,
_NATIVE_AMP_AVAILABLE,
GradClipAlgorithmType,
)

if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
@@ -28,8 +34,15 @@ class ShardedNativeAMPPlugin(NativeAMPPlugin):
def scaler(self):
return ShardedGradScaler()

def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
max_norm = grad_clip_val
norm_type = float(2.0)
optimizer = cast(OSS, optimizer)
optimizer.clip_grad_norm(max_norm, norm_type=norm_type)
def clip_gradients(self,
optimizer: Optimizer,
grad_clip_val: float,
gradient_clip_algorithm: str,
norm_type: float):

if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
model = self.trainer.get_model()
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=grad_clip_val)
elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
optimizer = cast(OSS, optimizer)
optimizer.clip_grad_norm(grad_clip_val, norm_type=norm_type)
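
With this change, value clipping reaches the sharded plugin as well. A hedged usage sketch (the ``'ddp_sharded'`` plugin alias and the 2-GPU setup are illustrative assumptions, not taken from this diff):

from pytorch_lightning import Trainer

trainer = Trainer(
    gpus=2,
    accelerator='ddp',
    plugins='ddp_sharded',
    gradient_clip_val=0.5,
    gradient_clip_algorithm='value',
)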
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -23,6 +26,7 @@ def __init__(self, trainer):
def on_trainer_init(
self,
gradient_clip_val,
gradient_clip_algorithm,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps,
@@ -32,7 +36,11 @@ def on_trainer_init(
self.trainer.terminate_on_nan = terminate_on_nan

# gradient clipping
if gradient_clip_algorithm not in [GradClipAlgorithmType.VALUE, GradClipAlgorithmType.NORM]:
raise MisconfigurationException(f"gradient_clip_algorithm should be "
f"'{GradClipAlgorithmType.VALUE}' or '{GradClipAlgorithmType.NORM}'")
self.trainer.gradient_clip_val = gradient_clip_val
self.trainer.gradient_clip_algorithm = gradient_clip_algorithm

# gradient norm tracking
if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
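
Given the validation above, an unrecognized algorithm name should fail fast when the Trainer is constructed. A hedged sketch of the expected behaviour (assuming the Trainer forwards ``gradient_clip_algorithm`` straight to this connector):

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

Trainer(gradient_clip_val=0.5)                                      # defaults to 'norm'
Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')     # clip each gradient element
try:
    Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='max')   # not a valid algorithm
except MisconfigurationException as err:
    print(err)  # raised by the check added above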