[bugfix] Apex never instantiated. #7274
Changes from all commits: ba35a45, 7e86c43, d097b34, 940132d, c8ad9c7, 796eb05, c4543ba, bd4523c, 8a75b52, 7d00bd3, 40492d6, 2144453, 82ab406, ae27b60, a071044, d28e8b6, 875f7ea, b48017d
```diff
@@ -1,3 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
 from unittest import mock
```
```diff
@@ -37,7 +51,7 @@ class MyApexPlugin(ApexMixedPrecisionPlugin):
     pytest.param('native', False, NativeMixedPrecisionPlugin, marks=RunIf(amp_native=True)),
     pytest.param('native', True, MyNativeAMP, marks=RunIf(amp_native=True)),
     pytest.param('apex', False, ApexMixedPrecisionPlugin, marks=RunIf(amp_apex=True)),
-    pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True))
+    pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True)),
 ]
 )
 def test_amp_apex_ddp(
```
|
```diff
@@ -83,3 +97,47 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
         accumulate_grad_batches=accum,
     )
     trainer.fit(model)
+
+
+@RunIf(min_gpus=2, amp_apex=True, special=True)
+@pytest.mark.parametrize("amp_level", ['O2'])
+def test_amp_apex_ddp_fit(amp_level, tmpdir):
```
Comment on lines +103 to +105

Contributor: We have another apex test in

Contributor (Author): Yes.
```diff
+
+    class CustomBoringModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            assert self.layer.weight.dtype == torch.float16
+            assert self.trainer.precision_plugin._connected
+            return super().training_step(batch, batch_idx)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        precision=16,
+        amp_backend="apex",
+        gpus=2,
+        accelerator='ddp',
+        plugins=ApexMixedPrecisionPlugin(amp_level=amp_level),
+    )
+    assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
+    model = CustomBoringModel()
+    trainer.fit(model)
+    trainer.test(model)
+
+
+@RunIf(min_gpus=2, amp_apex=True)
+@pytest.mark.parametrize("amp_level", ['O2'])
+def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        precision=16,
+        amp_backend="apex",
+        gpus=2,
+        accelerator='ddp_spawn',
+        plugins=ApexMixedPrecisionPlugin(amp_level=amp_level),
+    )
+    assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
+    model = BoringModel()
+    trainer.fit(model)
```
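For context on the assertion in `training_step` above: the test expects the layer's weights to be `torch.float16` once the Apex plugin has connected, which is the state Apex's `O2` opt level leaves the model in. A minimal standalone sketch outside Lightning (assuming NVIDIA Apex and a CUDA device are available; the variable names are illustrative):

```python
# Standalone sketch, not part of this PR: shows why the test above can expect
# float16 weights under opt_level "O2". Assumes NVIDIA Apex and a CUDA GPU.
import torch
from apex import amp

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# O2 casts the model's weights to half precision and keeps FP32 "master"
# weights inside the optimizer for the actual parameter updates.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
assert model.weight.dtype == torch.float16
```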
Comment: It might be a good idea to clear up somewhere here that this happens after accelerator setup? Otherwise this looks the same as `pre_dispatch`.
Comment: @tchaton As the order in which the hooks are executed could be confusing.
Comment: I find the pre/dispatch/post confusing now :/
Comment: Yes, we should think about the naming of these hooks. But more importantly, I think we can do a better job of formally defining what these hooks are supposed to do. Maybe another action item for 1.3 is to do a full pass over the plugins and improve all these docs. That would help everyone: 1. implementing plugins, 2. fixing them, 3. reviewing plugin PRs.
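To make that concrete, here is a deliberately simplified sketch of the kind of explicit ordering contract being asked for. The class and hook names below are illustrative stand-ins, not Lightning's actual plugin API:

```python
# Illustrative only: these stubs are not Lightning's real plugin classes; they
# just make one possible hook ordering explicit and printable.

class PrecisionStub:
    def connect(self, model):
        print("precision: configure the bare model (after accelerator setup)")
        return model

    def pre_dispatch(self):
        print("precision: pre_dispatch (right before training is dispatched)")


class TrainingTypeStub:
    def setup_environment(self):
        print("training type: set up devices / process group")

    def wrap(self, model):
        print("training type: wrap the model (e.g. in DDP)")
        self.model = model

    def pre_dispatch(self):
        print("training type: pre_dispatch")


class AcceleratorSketch:
    def __init__(self, precision, training_type):
        self.precision = precision
        self.training_type = training_type

    def setup(self, model):
        self.training_type.setup_environment()  # 1. environment comes first
        model = self.precision.connect(model)   # 2. precision sees the unwrapped model
        self.training_type.wrap(model)          # 3. wrapping happens last

    def pre_dispatch(self):
        self.precision.pre_dispatch()           # 4. a later, distinct hook
        self.training_type.pre_dispatch()


acc = AcceleratorSketch(PrecisionStub(), TrainingTypeStub())
acc.setup(model=object())
acc.pre_dispatch()
```

Documenting an order like this next to each hook would answer the "is this the same as `pre_dispatch`?" question above without having to read the accelerator source.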
Comment: n00b question: would this be easier if the precision plugin was owned by the training type plugin instead of the accelerator?
Comment: Yes, I think that could make it easier to interleave these operations between one plugin and the other. Here in this PR we see that the precision plugin needs to configure the model before it is wrapped, and needs to overwrite the reference in the training plugin. This really breaks the contract that these plugins currently have with each other.
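A toy sketch of the coupling described here (the class names are hypothetical, not the real Lightning plugins): the precision plugin must see the unwrapped model and then write the result back into a plugin it does not own.

```python
# Hypothetical sketch of the coupling described above, not the real Lightning
# code: the precision plugin mutates state owned by the training-type plugin.

class TrainingTypePluginStub:
    def __init__(self, model):
        self.model = model     # reference that another plugin ends up overwriting

    def wrap(self):
        self.model = ("DDP-wrapped", self.model)  # stand-in for DDP wrapping


class ApexLikePrecisionStub:
    def connect(self, training_type, optimizers):
        # Apex-style initialization has to run on the *unwrapped* model, so it
        # must happen before wrap() -- and the new model object has to be
        # written back into training_type, which is the contract break.
        model, optimizers = self._initialize(training_type.model, optimizers)
        training_type.model = model
        return optimizers

    def _initialize(self, model, optimizers):
        return ("apex-initialized", model), optimizers  # stand-in for amp.initialize


ttp = TrainingTypePluginStub(model="bare model")
ApexLikePrecisionStub().connect(ttp, optimizers=[])
ttp.wrap()
print(ttp.model)  # ('DDP-wrapped', ('apex-initialized', 'bare model'))
```

If the training-type plugin owned the precision plugin, as suggested above, this write-back could become an ordinary call the owner makes on its own state.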
Comment: Yes, we should definitely refactor this and move the optimizers and lr_schedulers to the training_type_plugin.