Skip to content

Commit 4f876b4

Browse files
committed
fix
1 parent c99bc32 commit 4f876b4

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

pytorch_lightning/accelerators/cpu_accelerator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,22 +61,22 @@ def train(self):
6161
results = self.train_or_test()
6262
return results
6363

64-
def _step(self, model_step: Callable, args):
64+
def _step(self, model_step: Callable, *args):
6565
if self.trainer.amp_backend == AMPType.NATIVE:
6666
with torch.cuda.amp.autocast():
6767
output = model_step(*args)
6868
else:
6969
output = model_step(*args)
7070
return output
7171

72-
def training_step(self, args):
73-
return self._step(self, self.trainer.model.training_step, args)
72+
def training_step(self, *args):
73+
return self._step(self.trainer.model.training_step, *args)
7474

75-
def validation_step(self, args):
76-
return self._step(self, self.trainer.model.validation_step, args)
75+
def validation_step(self, *args):
76+
return self._step(self.trainer.model.validation_step, *args)
7777

78-
def test_step(self, args):
79-
return self._step(self, self.trainer.model.test_step, args)
78+
def test_step(self, *args):
79+
return self._step(self.trainer.model.test_step, *args)
8080

8181
def sync_tensor(self,
8282
tensor: Union[torch.Tensor],

pytorch_lightning/accelerators/horovod_accelerator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def train(self):
114114
hvd.join()
115115
return results
116116

117-
def _step(self, model_step: Callable, args):
117+
def _step(self, model_step: Callable, *args):
118118
if self.trainer.on_gpu:
119119
batch = args[0]
120120
batch = self.batch_to_device(batch, hvd.local_rank())
@@ -128,14 +128,14 @@ def _step(self, model_step: Callable, args):
128128

129129
return output
130130

131-
def training_step(self, args):
132-
return self._step(self, self.trainer.model.training_step, args)
131+
def training_step(self, *args):
132+
return self._step(self.trainer.model.training_step, *args)
133133

134-
def validation_step(self, args):
135-
return self._step(self, self.trainer.model.validation_step, args)
134+
def validation_step(self, *args):
135+
return self._step(self.trainer.model.validation_step, *args)
136136

137-
def test_step(self, args):
138-
return self._step(self, self.trainer.model.test_step, args)
137+
def test_step(self, *args):
138+
return self._step(self.trainer.model.test_step, *args)
139139

140140
def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
141141
super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)

0 commit comments

Comments
 (0)