
Commit b71aa55

Make optimizers skippable when using amp (#7975)

Co-authored-by: Yifu Wang
Co-authored-by: Carlos Mocholi

1 parent 0004216 commit b71aa55

File tree

3 files changed: +52 -6 lines changed

CHANGELOG.md
pytorch_lightning/plugins/precision/native_amp.py
tests/plugins/test_amp_plugins.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -300,6 +300,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))


+- Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
+
+
 ## [1.3.2] - 2021-05-18

 ### Changed
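In native amp, the assertion the entry refers to comes from `torch.cuda.amp.GradScaler`: when a backward pass is skipped for an optimizer, no scaled gradients (and hence no inf checks) are recorded for it, so calling `scaler.step()` anyway fails. Below is a minimal plain-PyTorch sketch of that failure mode, outside Lightning; the model, optimizer, and `run_step` helper are illustrative assumptions, and a CUDA device is assumed.

```python
# Minimal plain-PyTorch sketch of the failure mode (not Lightning code; assumes a
# CUDA device; the model, optimizer, and run_step helper are illustrative only).
import torch

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()


def run_step(skip: bool) -> None:
    optimizer.zero_grad()
    if skip:
        # Backward was skipped, so no scaled gradients / inf checks exist for this
        # optimizer; calling scaler.step(optimizer) here would raise
        # "AssertionError: No inf checks were recorded for this optimizer."
        return
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 4, device="cuda")).sum()
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # safe: scaled gradients were recorded this iteration
    scaler.update()


run_step(skip=False)  # a normal optimizer step
run_step(skip=True)   # a skipped step: the scaler calls must be skipped too
```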

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 8 additions & 6 deletions
@@ -83,19 +83,21 @@ def pre_optimizer_step(
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
                 " To request, please file a Github issue in PyTorch and tag @mcarilli"
             )
-        lambda_closure()

         if not pl_module.automatic_optimization:
             self.scaler.unscale_(optimizer)
             pl_module.trainer.call_hook("on_after_backward")
+            self.scaler.step(optimizer)
+            self.scaler.update()
+        else:
+            result = lambda_closure()
+            # lambda_closure returning None indicates that backward has been skipped
+            if result is not None:
+                self.scaler.step(optimizer)
+                self.scaler.update()

         return False

-    def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
-        """Updates the GradScaler"""
-        self.scaler.step(optimizer)
-        self.scaler.update()
-
     @contextmanager
     def train_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""

tests/plugins/test_amp_plugins.py

Lines changed: 41 additions & 0 deletions
@@ -99,6 +99,47 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
     trainer.fit(model)


+@RunIf(min_gpus=1, amp_native=True)
+def test_amp_skip_optimizer(tmpdir):
+    """
+    Test that optimizers can be skipped when using amp
+    """
+
+    class CustomBoringModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.layer1 = torch.nn.Linear(32, 32)
+            self.layer2 = torch.nn.Linear(32, 2)
+
+        def forward(self, x: torch.Tensor):
+            x = self.layer1(x)
+            x = self.layer2(x)
+            return x
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            if optimizer_idx == 1:
+                return None
+            output = self(batch)
+            return self.loss(batch, output)
+
+        def configure_optimizers(self):
+            return [
+                torch.optim.SGD(self.layer1.parameters(), lr=0.1),
+                torch.optim.SGD(self.layer2.parameters(), lr=0.1),
+            ]
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        gpus=1,
+        fast_dev_run=1,
+        amp_backend='native',
+        precision=16,
+    )
+    model = CustomBoringModel()
+    trainer.fit(model)
+
+
 @RunIf(min_gpus=2, amp_apex=True, special=True)
 @pytest.mark.parametrize("amp_level", ['O2'])
 def test_amp_apex_ddp_fit(amp_level, tmpdir):
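As a usage note: on a machine that satisfies the `RunIf(min_gpus=1, amp_native=True)` guard, the new test can be selected on its own by pointing pytest at its node id, `tests/plugins/test_amp_plugins.py::test_amp_skip_optimizer`.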
