
Commit c2e6e68

tchaton, carmocca, Borda, williamFalcon, and SeanNaren authored
optimizer clean up (#4658)
* add LightningOptimizer
* typo
* add mock closure
* typo
* remove logic in optimizer_step
* update
* update
* update
* desactivate LightningOptimizer for hovorod
* resolve flake
* typo
* check optimizer name
* change name
* added backward to LightningOptimizer
* remove use_lightning_optimizer
* move update
* simplify init
* resolve comments
* resolve bug
* update
* update
* resolve bugs
* resolve flake8
* set state
* work manual_optimizer_step
* add doc
* add enable_pl_optimizer
* make optimizer_step
* add make_optimizer_step
* add examples
* resolve test
* add test_optimizer_return_options_enable_pl_optimizer
* add enable_pl_optimizer=True
* update
* update tests
* resolve bugs
* update
* set Trainer to False
* update
* resolve bugs
* update
* remove from doc
* resolve bug
* typo
* update
* set to True
* simplification
* typo
* resolve horovod
* unwrap horovod
* remove Optimizer
* resolve horovod
* move logic to amp_backend
* doesn't seem to be pickable
* update
* add again
* resolve some bugs
* cleanup
* resolve bug with AMP
* change __repr__
* round at -12
* udpate
* update
* update
* remove from horovod
* typo
* add convert_to_lightning_optimizers in each accelerators
* typo
* forgot
* forgot a convert_to_lightning_optimizers
* update
* update
* update
* increase coverage
* update
* resolve flake8
* update
* remove useless code
* resolve comments + add support for LightningOptimizer base class
* resolve flake
* check optimizer get wrapped back
* resolve DDPSharded
* reduce code
* lightningoptimizer
* Update pytorch_lightning/core/optimizer.py Co-authored-by: Carlos Mocholí <[email protected]>
* Update pytorch_lightning/core/lightning.py
* remove reference to step function
* Apply suggestions from code review
* update on comments
* resolve
* Update CHANGELOG.md
* add back training_step in apex and native_amp
* rename optimizer_step

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
1 parent 2fe1eff commit c2e6e68

36 files changed, +741 -332 lines changed
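
For orientation, a minimal sketch of the user-facing API this commit touches. The `enable_pl_optimizer` Trainer flag and the `LightningOptimizer` wrapper are named in the commit message above; the exact calls below are illustrative assumptions for the release this commit targets, not verbatim library code.

    # Hedged sketch (assumed API of the release this commit targets):
    # with enable_pl_optimizer=True, each optimizer returned from
    # configure_optimizers() is wrapped in a LightningOptimizer by the
    # accelerator's convert_to_lightning_optimizers() call added here.
    from pytorch_lightning import Trainer

    trainer = Trainer(enable_pl_optimizer=True)  # flag added in this PR, per the commit message

    # In manual optimization, user code now calls optimizer.step() directly
    # instead of the removed LightningModule.manual_optimizer_step().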

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added printing of total num of params, trainable and non-trainable params in ModelSummary ([#4521](https://github.com/PyTorchLightning/pytorch-lightning/pull/4521))
 
+- Added optimizer refactors ([#4658](https://github.com/PyTorchLightning/pytorch-lightning/pull/4658))
+
 
 ### Changed

docs/source/lightning_module.rst

Lines changed: 0 additions & 5 deletions
@@ -1009,11 +1009,6 @@ manual_backward
 .. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward
    :noindex:
 
-manual_optimizer_step
-~~~~~~~~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_optimizer_step
-   :noindex:
 
 on_after_backward
 ~~~~~~~~~~~~~~~~~

docs/source/optimizers.rst

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ to manually manage the optimization process. To do so, do the following:
 
 # use self.backward which will also handle scaling the loss when using amp
 self.manual_backward(loss_a, opt_g)
-self.manual_optimizer_step(opt_g)
+opt_g.step()
 
 
 # do anything you want
@@ -45,7 +45,7 @@ to manually manage the optimization process. To do so, do the following:
 # pass in any args that loss.backward() normally takes
 self.manual_backward(loss_b, opt_d, retain_graph=True)
 self.manual_backward(loss_b, opt_d)
-self.manual_optimizer_step(opt_d)
+opt_d.step()
 
 # log losses
 self.log('loss_a', loss_a)
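
To make the documented change concrete, here is a hedged sketch of the manual-optimization pattern after this commit: plain `opt.step()` replaces the removed `self.manual_optimizer_step(opt)`. The `self.optimizers()` accessor and the two loss helpers are illustrative assumptions, not part of the diff.

    # Sketch only; assumes a LightningModule configured for manual optimization
    # with two optimizers (e.g. a GAN-style generator/discriminator pair).
    def training_step(self, batch, batch_idx, optimizer_idx):
        (opt_g, opt_d) = self.optimizers()  # assumed accessor for the configured optimizers

        loss_a = self.compute_generator_loss(batch)      # hypothetical helper
        # use self.manual_backward which also handles AMP loss scaling
        self.manual_backward(loss_a, opt_g)
        opt_g.step()        # replaces self.manual_optimizer_step(opt_g)
        opt_g.zero_grad()

        loss_b = self.compute_discriminator_loss(batch)  # hypothetical helper
        self.manual_backward(loss_b, opt_d, retain_graph=True)
        self.manual_backward(loss_b, opt_d)
        opt_d.step()        # replaces self.manual_optimizer_step(opt_d)
        opt_d.zero_grad()

        # log losses
        self.log('loss_a', loss_a)
        self.log('loss_b', loss_b)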

pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 7 additions & 8 deletions
@@ -11,18 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pytorch_lightning import accelerators
 import os
+
 import torch
 
-from pytorch_lightning.utilities import device_parser, XLA_AVAILABLE
-from pytorch_lightning.utilities import rank_zero_only
-from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning import _logger as log
+from pytorch_lightning import accelerators
+from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.utilities import XLA_AVAILABLE, device_parser, rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 try:
     import horovod.torch as hvd
@@ -397,8 +397,7 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
     def determine_local_rank(self):
         if self.trainer.is_slurm_managing_tasks:
             return int(os.environ['SLURM_LOCALID'])
-        else:
-            return int(os.environ.get('LOCAL_RANK', 0))
+        return int(os.environ.get('LOCAL_RANK', 0))
 
     def determine_ddp_node_rank(self):
         if self.trainer.is_slurm_managing_tasks:

pytorch_lightning/accelerators/cpu_accelerator.py

Lines changed: 3 additions & 1 deletion
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Union, Any
+from typing import Any, Optional, Union
 
 import torch
 
@@ -47,6 +47,8 @@ def setup(self, model):
         # allow for lr schedulers as well
         self.setup_optimizers(model)
 
+        self.trainer.convert_to_lightning_optimizers()
+
         self.trainer.model = model
 
     def train(self):
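
Each accelerator now calls `self.trainer.convert_to_lightning_optimizers()` once its optimizers are set up (and, in the DDP variants below, after the precision connector runs). The helper below is a hedged sketch of what that call is conceptually doing, wrapping each plain torch optimizer in the new `LightningOptimizer`; it is inferred from this diff and the commit message, not copied from the library.

    # Conceptual sketch only (assumed behaviour, inferred from this commit):
    from pytorch_lightning.core.optimizer import LightningOptimizer

    def convert_to_lightning_optimizers(trainer):
        # wrap every configured optimizer so optimizer.step() is routed
        # through Lightning (closure handling, AMP scaling, etc.)
        trainer.optimizers = [
            opt if isinstance(opt, LightningOptimizer) else LightningOptimizer(opt)
            for opt in trainer.optimizers
        ]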

pytorch_lightning/accelerators/ddp2_accelerator.py

Lines changed: 8 additions & 6 deletions
@@ -12,23 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import os
+from typing import Any, List, Optional, Union
 
 import torch
 import torch.distributed as torch_distrib
+from torch.nn.parallel import DistributedDataParallel
 
+from pytorch_lightning import _logger as log
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.distributed.dist import LightningDistributed
-from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
-from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
+from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
-from torch.nn.parallel import DistributedDataParallel
-from typing import List, Optional, Union, Any
 
 if HYDRA_AVAILABLE:
-    from hydra.utils import to_absolute_path, get_original_cwd
     from hydra.core.hydra_config import HydraConfig
+    from hydra.utils import get_original_cwd, to_absolute_path
 
 
 class DDP2Accelerator(Accelerator):
@@ -170,6 +170,8 @@ def ddp_train(self, process_idx, mp_queue, model):
         # 16-bit
         model = self.trainer.precision_connector.connect(model)
 
+        self.trainer.convert_to_lightning_optimizers()
+
         # device ids change depending on the DDP setup
         device_ids = self.get_device_ids()

pytorch_lightning/accelerators/ddp_accelerator.py

Lines changed: 9 additions & 10 deletions
@@ -12,32 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import os
-import torch
-import torch.distributed as torch_distrib
 import subprocess
 import sys
 from os.path import abspath
 from time import sleep
-from typing import Any, Optional, List, Union
+from typing import Any, List, Optional, Union
 
 import numpy as np
+import torch
+import torch.distributed as torch_distrib
+from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
-from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
-from pytorch_lightning.utilities.distributed import find_free_network_port
-from pytorch_lightning.utilities.distributed import rank_zero_only
-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
+from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
+from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
-from torch.nn.parallel import DistributedDataParallel
-
 
 if HYDRA_AVAILABLE:
-    from hydra.utils import to_absolute_path, get_original_cwd
     from hydra.core.hydra_config import HydraConfig
+    from hydra.utils import get_original_cwd, to_absolute_path
 
 
 class DDPAccelerator(Accelerator):
@@ -266,6 +263,8 @@ def ddp_train(self, process_idx, model):
         # 16-bit
         model = self.trainer.precision_connector.connect(model)
 
+        self.trainer.convert_to_lightning_optimizers()
+
         # device ids change depending on the DDP setup
         device_ids = self.get_device_ids()

pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py

Lines changed: 9 additions & 3 deletions
@@ -23,10 +23,14 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
-from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
-from pytorch_lightning.utilities.distributed import find_free_network_port, sync_ddp_if_available
 from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
+from pytorch_lightning.utilities.distributed import (
+    find_free_network_port,
+    rank_zero_only,
+    rank_zero_warn,
+    sync_ddp_if_available,
+)
 
 if HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
@@ -130,6 +134,8 @@ def ddp_train(self, process_idx, mp_queue, model):
         # 16-bit
         model = self.trainer.precision_connector.connect(model)
 
+        self.trainer.convert_to_lightning_optimizers()
+
         # DDP spawn already spawned off each process... no need to do anything
         device_ids = self.get_device_ids()

pytorch_lightning/accelerators/ddp_hpc_accelerator.py

Lines changed: 4 additions & 3 deletions
@@ -23,13 +23,12 @@
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
-from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
+from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
 
-
 if HYDRA_AVAILABLE:
-    from hydra.utils import to_absolute_path, get_original_cwd
     from hydra.core.hydra_config import HydraConfig
+    from hydra.utils import get_original_cwd, to_absolute_path
 
 
 class DDPHPCAccelerator(Accelerator):
@@ -164,6 +163,8 @@ def ddp_train(self, process_idx, model):
         # 16-bit
         model = self.trainer.precision_connector.connect(model)
 
+        self.trainer.convert_to_lightning_optimizers()
+
         # device ids change depending on the DDP setup
         device_ids = self.get_device_ids()

pytorch_lightning/accelerators/ddp_spawn_accelerator.py

Lines changed: 13 additions & 6 deletions
@@ -16,20 +16,25 @@
 from typing import Any, List, Optional, Union
 
 import torch
-import torch.multiprocessing as mp
 import torch.distributed as torch_distrib
 import torch.distributed as dist
+import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
-from pytorch_lightning.utilities.cloud_io import atomic_save, load as pl_load
-from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, find_free_network_port
-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
+from pytorch_lightning.distributed import LightningDistributed
+from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
+from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.distributed import (
+    find_free_network_port,
+    rank_zero_only,
+    rank_zero_warn,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.seed import seed_everything
-from pytorch_lightning.distributed.dist import LightningDistributed
 
 if HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
@@ -141,6 +146,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         # 16-bit
         model = self.trainer.precision_connector.connect(model)
 
+        self.trainer.convert_to_lightning_optimizers()
+
         # device ids change depending on the DDP setup
         device_ids = self.get_device_ids()

0 commit comments
