Add Trainer(gradient_clip_algorithm='value'|'norm')
#6123
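For context, the behaviour this PR adds is driven by two `Trainer` arguments: the existing `gradient_clip_val` threshold and the new `gradient_clip_algorithm` named in the title. A minimal usage sketch; the 0.5 threshold is illustrative only:

```python
from pytorch_lightning import Trainer

# "value" clamps each gradient element to [-0.5, 0.5];
# the default "norm" rescales gradients by their global norm, as before.
trainer = Trainer(
    gradient_clip_val=0.5,
    gradient_clip_algorithm="value",  # or "norm"
)
```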
Changes from all commits
```diff
@@ -17,6 +17,7 @@
 import torch
 
 from pytorch_lightning.plugins.base_plugin import Plugin
+from pytorch_lightning.utilities import GradClipAlgorithmType
 
 if TYPE_CHECKING:
     from torch.nn import Module
@@ -33,6 +34,13 @@ class PrecisionPlugin(Plugin):
     EPSILON: float = 1e-6
     precision: Union[str, int] = 32
 
+    def __init__(self) -> None:
+        super().__init__()
+        self.clip_grad_funcs = {
+            GradClipAlgorithmType.VALUE: self.clip_grad_by_value,
+            GradClipAlgorithmType.NORM: self.clip_grad_by_norm,
+        }
+
     def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
         """The master params of the model. Returns the plain model params here.
         Maybe different in other precision plugins.
@@ -103,20 +111,29 @@ def clip_gradients(
         model: 'LightningModule',
         optimizer: 'Optimizer',
         clip_val: Union[int, float],
-        norm_type: float = 2.0
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
-        """Clips the gradients to a specific value"""
+        """Clips the gradients"""
         if clip_val is None:
             return
 
-        grad_clip_val = float(clip_val)
-
-        if grad_clip_val <= 0:
+        clip_val = float(clip_val)
+        if clip_val <= 0:
             return
 
+        clip_grad_func = self.clip_grad_funcs[gradient_clip_algorithm]
+        clip_grad_func(optimizer, clip_val)  # type: ignore
+
+    def clip_grad_by_value(self, optimizer: 'Optimizer', clip_val: Union[int, float]) -> None:
```
Review thread on the new clip_grad_by_value / clip_grad_by_norm methods:

Contributor: Should we make them private?

Contributor: @tchaton what's expected when other precision plugins override them, as the sharded native amp one does?

Collaborator: or protected :]
```diff
+        """Clip gradients by value"""
+        parameters = list(self.master_params(optimizer))
+        torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
+
-        max_norm = grad_clip_val
-
+    def clip_grad_by_norm(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
+        """Clip gradients by norm"""
+        # TODO: separate TPU case from here
         parameters = list(self.master_params(optimizer))
+        max_norm = clip_val
 
         if isinstance(parameters, torch.Tensor):
             parameters = [parameters]
```
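The `clip_grad_funcs` mapping set up in `__init__` simply routes each `GradClipAlgorithmType` to one of the two `torch.nn.utils` helpers. A self-contained sketch of the same dispatch outside Lightning (the toy model and the 0.5 threshold are illustrative only):

```python
import torch
from torch import nn

# Toy model with populated gradients so there is something to clip.
model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

clip_val = 0.5  # illustrative threshold

# Mirror of the plugin's dict-based dispatch: algorithm name -> clipping function.
clip_grad_funcs = {
    # clamp every gradient element into [-clip_val, clip_val]
    "value": lambda params: torch.nn.utils.clip_grad_value_(params, clip_value=clip_val),
    # rescale all gradients so their combined 2-norm is at most clip_val
    "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, max_norm=clip_val, norm_type=2.0),
}

clip_grad_funcs["value"](list(model.parameters()))
```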
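On the override question raised in the review thread: because `clip_grad_funcs` stores methods looked up on `self` inside `__init__`, a subclass that overrides `clip_grad_by_norm` (or `clip_grad_by_value`) is picked up automatically as long as it calls `super().__init__()`. A sketch of such an override; the subclass name is hypothetical and the import path is an assumption:

```python
from typing import Union

import torch

from pytorch_lightning.plugins.precision import PrecisionPlugin  # assumed import path


class MyShardedPrecisionPlugin(PrecisionPlugin):
    """Hypothetical plugin that customises only the norm-based clipping."""

    def clip_grad_by_norm(self, optimizer, clip_val: Union[int, float], norm_type: float = 2.0) -> None:
        # A sharded plugin might need to gather the norm across shards first;
        # this sketch just falls back to the standard utility.
        parameters = list(self.master_params(optimizer))
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=float(clip_val), norm_type=norm_type)
```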