Skip to content

Commit 644e1c5

Browse files
committed
simplify accelerator steps
1 parent aeaa6b2 commit 644e1c5

File tree

4 files changed

+26
-49
lines changed

4 files changed

+26
-49
lines changed

pytorch_lightning/accelerators/cpu_accelerator.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, Optional, Union
14+
from typing import Any, Optional, Union, Callable
1515

1616
import torch
1717

@@ -61,29 +61,22 @@ def train(self):
6161
results = self.train_or_test()
6262
return results
6363

64-
def training_step(self, args):
64+
def _step(self, model_step: Callable, args):
6565
if self.trainer.amp_backend == AMPType.NATIVE:
6666
with torch.cuda.amp.autocast():
67-
output = self.trainer.model.training_step(*args)
67+
output = model_step(*args)
6868
else:
69-
output = self.trainer.model.training_step(*args)
69+
output = model_step(*args)
7070
return output
7171

72+
def training_step(self, args):
73+
return self._step(self.trainer.model.training_step, args)
74+
7275
def validation_step(self, args):
73-
if self.trainer.amp_backend == AMPType.NATIVE:
74-
with torch.cuda.amp.autocast():
75-
output = self.trainer.model.validation_step(*args)
76-
else:
77-
output = self.trainer.model.validation_step(*args)
78-
return output
76+
return self._step(self.trainer.model.validation_step, args)
7977

8078
def test_step(self, args):
81-
if self.trainer.amp_backend == AMPType.NATIVE:
82-
with torch.cuda.amp.autocast():
83-
output = self.trainer.model.test_step(*args)
84-
else:
85-
output = self.trainer.model.test_step(*args)
86-
return output
79+
return self._step(self.trainer.model.test_step, args)
8780

8881
def sync_tensor(self,
8982
tensor: Union[torch.Tensor],

pytorch_lightning/accelerators/dp_accelerator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,21 +116,22 @@ def teardown(self):
116116
self.trainer.model.forward = self.model_autocast_original_forward
117117
self.barrier()
118118

119-
def training_step(self, args):
119+
def _step(self, args):
120120
if self.trainer.amp_backend == AMPType.NATIVE:
121121
with torch.cuda.amp.autocast():
122122
output = self.trainer.model(*args)
123123
else:
124124
output = self.trainer.model(*args)
125125
return output
126126

127+
def training_step(self, args):
128+
return self._step(args)
129+
127130
def validation_step(self, args):
128-
output = self.training_step(args)
129-
return output
131+
return self._step(args)
130132

131133
def test_step(self, args):
132-
output = self.training_step(args)
133-
return output
134+
return self._step(args)
134135

135136
def training_step_end(self, output):
136137
if isinstance(output, Result):

pytorch_lightning/accelerators/horovod_accelerator.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from contextlib import ExitStack
15-
from typing import Any, Optional, Union
15+
from typing import Any, Optional, Union, Callable
1616

1717
import torch
1818
from torch.optim.lr_scheduler import _LRScheduler
@@ -114,46 +114,28 @@ def train(self):
114114
hvd.join()
115115
return results
116116

117-
def training_step(self, args):
117+
def _step(self, model_step: Callable, args):
118118
if self.trainer.on_gpu:
119119
batch = args[0]
120120
batch = self.batch_to_device(batch, hvd.local_rank())
121121
args[0] = batch
122122

123123
if self.trainer.amp_backend == AMPType.NATIVE:
124124
with torch.cuda.amp.autocast():
125-
output = self.trainer.model.training_step(*args)
125+
output = model_step(*args)
126126
else:
127-
output = self.trainer.model.training_step(*args)
127+
output = model_step(*args)
128128

129129
return output
130130

131-
def validation_step(self, args):
132-
if self.trainer.on_gpu:
133-
batch = args[0]
134-
batch = self.batch_to_device(batch, hvd.local_rank())
135-
args[0] = batch
136-
137-
if self.trainer.amp_backend == AMPType.NATIVE:
138-
with torch.cuda.amp.autocast():
139-
output = self.trainer.model.validation_step(*args)
140-
else:
141-
output = self.trainer.model.validation_step(*args)
131+
def training_step(self, args):
132+
return self._step(self.trainer.model.training_step, args)
142133

143-
return output
134+
def validation_step(self, args):
135+
return self._step(self.trainer.model.validation_step, args)
144136

145137
def test_step(self, args):
146-
if self.trainer.on_gpu:
147-
batch = args[0]
148-
batch = self.batch_to_device(batch, hvd.local_rank())
149-
args[0] = batch
150-
151-
if self.trainer.amp_backend == AMPType.NATIVE:
152-
with torch.cuda.amp.autocast():
153-
output = self.trainer.model.test_step(*args)
154-
else:
155-
output = self.trainer.model.test_step(*args)
156-
return output
138+
return self._step(self.trainer.model.test_step, args)
157139

158140
def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
159141
super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)

pytorch_lightning/trainer/connectors/slurm_connector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def configure_slurm_ddp(self, num_gpu_nodes):
5454
if self.trainer.is_slurm_managing_tasks:
5555
rank_zero_info('Multi-processing is handled by Slurm.')
5656

57+
# todo: the same function as slurm_environment.py `_resolve_root_node_address`
5758
def resolve_root_node_address(self, root_node):
5859
if '[' in root_node:
5960
name, numbers = root_node.split('[', maxsplit=1)
@@ -108,8 +109,8 @@ def term_handler(self, signum, frame):
108109
# save
109110
log.info("bypassing sigterm")
110111

112+
# todo: this is the same func as slurm_environment.py `master_port`
111113
def connect_ddp(self, global_rank: int, world_size: int) -> None:
112-
""""""
113114
"""
114115
Sets up environment variables necessary for pytorch distributed communications
115116
based on slurm environment.

0 commit comments

Comments
 (0)