
Commit f17967e

Merge branch 'master' into maxjeblick/master
2 parents a7724bf + d5fa02e commit f17967e

4 files changed, +27 -52 lines changed

pytorch_lightning/accelerators/cpu_accelerator.py

Lines changed: 9 additions & 16 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Callable

 import torch

@@ -62,29 +62,22 @@ def train(self):
         results = self.train_or_test()
         return results

-    def training_step(self, args):
+    def _step(self, model_step: Callable, args):
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
-                output = self.trainer.model.training_step(*args)
+                output = model_step(*args)
         else:
-            output = self.trainer.model.training_step(*args)
+            output = model_step(*args)
         return output

+    def training_step(self, args):
+        return self._step(self.trainer.model.training_step, args)
+
     def validation_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.validation_step(*args)
-        else:
-            output = self.trainer.model.validation_step(*args)
-        return output
+        return self._step(self.trainer.model.validation_step, args)

     def test_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.test_step(*args)
-        else:
-            output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)

     def sync_tensor(self,
                     tensor: Union[torch.Tensor],
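
The change above collapses three near-identical step methods into a single `_step` helper that owns the native-AMP autocast branch. A minimal standalone sketch of the pattern, assuming a toy accelerator and model outside PyTorch Lightning (the `Toy*` names are illustrative, not from the diff):

import torch
from typing import Callable

class ToyCPUAccelerator:
    # Sketch of the delegation pattern: one _step helper, three thin wrappers.

    def __init__(self, model, use_native_amp: bool = False):
        self.model = model
        self.use_native_amp = use_native_amp

    def _step(self, model_step: Callable, args):
        # The autocast decision now lives in exactly one place.
        if self.use_native_amp:
            with torch.cuda.amp.autocast():
                return model_step(*args)
        return model_step(*args)

    def training_step(self, args):
        return self._step(self.model.training_step, args)

    def validation_step(self, args):
        return self._step(self.model.validation_step, args)

    def test_step(self, args):
        return self._step(self.model.test_step, args)

Each public method stays a one-liner, so supporting another hook would only require one more delegation line rather than another copy of the autocast branch.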

pytorch_lightning/accelerators/dp_accelerator.py

Lines changed: 6 additions & 5 deletions

@@ -118,21 +118,22 @@ def teardown(self):
         self.trainer.model.forward = self.model_autocast_original_forward
         self.barrier()

-    def training_step(self, args):
+    def _step(self, args):
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
         else:
             output = self.trainer.model(*args)
         return output

+    def training_step(self, args):
+        return self._step(args)
+
     def validation_step(self, args):
-        output = self.training_step(args)
-        return output
+        return self._step(args)

     def test_step(self, args):
-        output = self.training_step(args)
-        return output
+        return self._step(args)

     def training_step_end(self, output):
         if isinstance(output, Result):
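
In the DP accelerator, `_step` takes no `model_step` callable: under data-parallel execution the call goes through the wrapped model's `forward`, and it is the wrapper that routes to the appropriate `*_step`. A hedged sketch of that routing idea, with a hypothetical `ToyDPWrapper` standing in for Lightning's data-parallel wrapper:

class ToyDPWrapper:
    # Illustrative stand-in: calling the wrapper dispatches to the right *_step hook.

    def __init__(self, module, stage: str = "train"):
        self.module = module
        self.stage = stage  # "train", "val", or "test"

    def __call__(self, *args):
        step = {
            "train": self.module.training_step,
            "val": self.module.validation_step,
            "test": self.module.test_step,
        }[self.stage]
        return step(*args)

With such a wrapper sitting in `self.trainer.model`, the call `self.trainer.model(*args)` is already stage-aware, which is why all three accelerator methods can share the argument-free `_step(args)`.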

pytorch_lightning/accelerators/horovod_accelerator.py

Lines changed: 10 additions & 30 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import ExitStack
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Callable

 import torch
 from torch.optim.lr_scheduler import _LRScheduler

@@ -116,46 +116,26 @@ def train(self):
         hvd.join()
         return results

-    def training_step(self, args):
+    def _step(self, model_step: Callable, args):
         if self.trainer.on_gpu:
-            batch = args[0]
-            batch = self.batch_to_device(batch, hvd.local_rank())
-            args[0] = batch
+            args[0] = self.batch_to_device(args[0], hvd.local_rank())

         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
-                output = self.trainer.model.training_step(*args)
+                output = model_step(*args)
         else:
-            output = self.trainer.model.training_step(*args)
+            output = model_step(*args)

         return output

-    def validation_step(self, args):
-        if self.trainer.on_gpu:
-            batch = args[0]
-            batch = self.batch_to_device(batch, hvd.local_rank())
-            args[0] = batch
-
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.validation_step(*args)
-        else:
-            output = self.trainer.model.validation_step(*args)
+    def training_step(self, args):
+        return self._step(self.trainer.model.training_step, args)

-        return output
+    def validation_step(self, args):
+        return self._step(self.trainer.model.validation_step, args)

     def test_step(self, args):
-        if self.trainer.on_gpu:
-            batch = args[0]
-            batch = self.batch_to_device(batch, hvd.local_rank())
-            args[0] = batch
-
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.test_step(*args)
-        else:
-            output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)

     def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
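
The Horovod version of `_step` also moves the first positional argument (the batch) onto the process-local GPU before delegating, replacing three copies of the same three-line move with a single assignment. A hedged sketch of that device move; `toy_batch_to_device` is a made-up helper standing in for the accelerator's `batch_to_device`:

import torch

def toy_batch_to_device(batch, device):
    # Recursively move tensors in a (possibly nested) batch onto the device.
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(toy_batch_to_device(b, device) for b in batch)
    if isinstance(batch, dict):
        return {k: toy_batch_to_device(v, device) for k, v in batch.items()}
    return batch

# Usage mirroring the diff: args is [batch, batch_idx], and hvd.local_rank()
# selects this process's GPU before the LightningModule step runs, e.g.
# args[0] = toy_batch_to_device(args[0], torch.device("cuda", hvd.local_rank()))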

pytorch_lightning/trainer/connectors/slurm_connector.py

Lines changed: 2 additions & 1 deletion

@@ -54,6 +54,7 @@ def configure_slurm_ddp(self, num_gpu_nodes):
         if self.trainer.is_slurm_managing_tasks:
             rank_zero_info('Multi-processing is handled by Slurm.')

+    # todo: the same function as slurm_environment.py `_resolve_root_node_address`
     def resolve_root_node_address(self, root_node):
         if '[' in root_node:
             name, numbers = root_node.split('[', maxsplit=1)

@@ -108,8 +109,8 @@ def term_handler(self, signum, frame):
         # save
         log.info("bypassing sigterm")

+    # todo: this is the same func as slurm_environment.py `master_port`
     def connect_ddp(self, global_rank: int, world_size: int) -> None:
-        """"""
         """
         Sets up environment variables necessary for pytorch distributed communications
         based on slurm environment.
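
The two `# todo` comments flag that these helpers duplicate logic in `slurm_environment.py`. For context, a hedged sketch of what a SLURM node-list resolver along the lines of `resolve_root_node_address` does; everything past the `split('[', maxsplit=1)` shown in the diff is an assumption for illustration, not the repository's code:

import re

def resolve_root_node_address(root_node: str) -> str:
    # Reduce a SLURM node list such as "host[17-25,30]" to its first host,
    # e.g. "host17", which can then serve as the DDP root node address.
    if '[' in root_node:
        name, numbers = root_node.split('[', maxsplit=1)
        number = numbers.split(',', maxsplit=1)[0]  # keep the first entry
        if '-' in number:
            number = number.split('-')[0]           # keep the first node of a range
        number = re.sub(r'[^0-9]', '', number)      # drop the trailing ']'
        root_node = name + number
    return root_node

# resolve_root_node_address("host[17-25]") -> "host17"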
