-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Accelerator Refactor: Precision Plugins #5718
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
a6e6e20
add basic accelerator class.
justusschock 3c5b0a6
add basic trainign type plugin.
justusschock 5ab4b9e
pep8
justusschock ab4660d
update copyright
justusschock 1ade664
add apex_amp
justusschock bcd0e7a
add mixed base class
justusschock a287567
add native amp
justusschock 1ff9b42
add native amp sharded
justusschock 1b82554
add tpu bfloat
justusschock 791fe19
add inits
justusschock 94e0b28
Update precision_plugin.py
justusschock 97eab50
Update pytorch_lightning/plugins/base_plugin.py
justusschock de8fe1b
fix imports
justusschock 60445ad
Merge branch 'ref/precision_plugins' of github.com:PyTorchLightning/p…
justusschock File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,6 @@ | ||
|
|
||
| from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin | ||
| from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin | ||
| from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin | ||
| from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin | ||
| from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin | ||
| from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| # Copyright The PyTorch Lightning team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from typing import List, Tuple | ||
|
|
||
| import torch | ||
| from torch.optim import Optimizer | ||
|
|
||
| from pytorch_lightning.core import LightningModule | ||
| from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin | ||
| from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, rank_zero_warn | ||
|
|
||
| if _APEX_AVAILABLE: | ||
| from apex import amp | ||
|
|
||
|
|
||
| class ApexMixedPrecisionPlugin(MixedPrecisionPlugin): | ||
| """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)""" | ||
|
|
||
| def __init__(self, amp_level: str): | ||
| self.backend = AMPType.APEX | ||
| self.amp_level = amp_level | ||
|
|
||
| def master_params(self, optimizer: torch.optim.Optimizer): | ||
| return amp.master_params(optimizer) | ||
|
|
||
| def connect(self, model: torch.nn.Module, optimizers, lr_schedulers): | ||
| """Connects the precision plugin to the training process, | ||
| configures apex and reinits the schedulers | ||
| """ | ||
| model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level) | ||
| self.reinit_scheduler_properties(optimizers, lr_schedulers) | ||
| return model, optimizers, lr_schedulers | ||
|
|
||
| def backward( | ||
| self, | ||
| model: LightningModule, | ||
| closure_loss: torch.Tensor, | ||
| optimizer: torch.optim.Optimizer, | ||
| opt_idx: int, | ||
| should_accumulate: bool, | ||
| *args, | ||
| **kwargs, | ||
| ): | ||
| """performs the actual backpropagation | ||
|
|
||
| Args: | ||
| model: the model to be optimized | ||
| closure_loss: the loss value obtained from the closure | ||
| optimizer: the optimizer to perform the step lateron | ||
| opt_idx: the optimizer's index | ||
| should_accumulate: whether to accumulate gradients or not | ||
|
|
||
| """ | ||
| closure_loss = amp.scale_loss(closure_loss, optimizer) | ||
|
|
||
| # enter apex context | ||
| context = closure_loss | ||
| closure_loss = closure_loss.__enter__() | ||
|
|
||
| # do backward pass | ||
| # TODO: not entirely sure, why we need this | ||
| if model is not None and isinstance(model, LightningModule): | ||
| model.backward(closure_loss, optimizer, opt_idx) | ||
| else: | ||
| closure_loss.backward(*args, **kwargs) | ||
|
|
||
| # exit amp context | ||
| a, b, c = None, None, None | ||
| error = context.__exit__(a, b, c) | ||
| if error: | ||
| rank_zero_warn(a, b, c) | ||
| raise Exception("apex unscale error") | ||
|
|
||
| # once backward has been applied, release graph | ||
| closure_loss = closure_loss.detach() | ||
| return closure_loss | ||
|
|
||
| def configure_apex( | ||
| self, | ||
| amp: object, | ||
| model: LightningModule, | ||
| optimizers: List[Optimizer], | ||
| amp_level: str, | ||
| ) -> Tuple[LightningModule, List[Optimizer]]: | ||
| r""" | ||
| Override to init AMP your own way. | ||
| Must return a model and list of optimizers. | ||
|
|
||
| Args: | ||
| amp: pointer to amp library object. | ||
| model: pointer to current :class:`LightningModule`. | ||
| optimizers: list of optimizers passed in :meth:`configure_optimizers`. | ||
| amp_level: AMP mode chosen ('O1', 'O2', etc...) | ||
|
|
||
| Return: | ||
| Apex wrapped model and optimizers | ||
|
|
||
| Examples: | ||
| .. code-block:: python | ||
|
|
||
| # Default implementation used by Trainer. | ||
| def configure_apex(self, amp, model, optimizers, amp_level): | ||
| model, optimizers = amp.initialize( | ||
| model, optimizers, opt_level=amp_level, | ||
| ) | ||
|
|
||
| return model, optimizers | ||
| """ | ||
| model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level) | ||
| return model, optimizers | ||
|
|
||
| @staticmethod | ||
| def reinit_scheduler_properties(optimizers: list, schedulers: list): | ||
| """Reinitializes schedulers with correct properties""" | ||
| # Reinitialize optimizer.step properties added by schedulers | ||
| for scheduler in schedulers: | ||
| scheduler = scheduler["scheduler"] | ||
|
|
||
| for optimizer in optimizers: | ||
| state = None | ||
| idx = 0 | ||
|
|
||
| # check that we dont mix users optimizers and schedulers | ||
| if scheduler.optimizer == optimizer: | ||
| # Find the mro belonging to the base lr scheduler class | ||
| for i, mro in enumerate(scheduler.__class__.__mro__): | ||
| if mro in (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): | ||
| idx = i | ||
| state = scheduler.state_dict() | ||
| else: | ||
| state = None | ||
|
|
||
| scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer) | ||
| if state is not None: | ||
| scheduler.load_state_dict(state) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| # Copyright The PyTorch Lightning team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin | ||
| from pytorch_lightning.utilities import AMPType | ||
|
|
||
|
|
||
| class MixedPrecisionPlugin(PrecisionPlugin): | ||
| """Base Class for mixed precision""" | ||
|
|
||
| EPSILON = 1e-5 | ||
| backend: AMPType | ||
| precision = "mixed" | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.