
Commit 677bb52

Revert "Remove unused param tpu_core_idx (#1948)"
This reverts commit d0ec11b.
1 parent c967b88 commit 677bb52

File tree

2 files changed: +2 −2

pytorch_lightning/trainer/distrib_parts.py (1 addition, 1 deletion)

@@ -501,7 +501,7 @@ def single_gpu_train(self, model):
 
         self.run_pretrain_routine(model)
 
-    def tpu_train(self, model):
+    def tpu_train(self, tpu_core_idx, model):
         # put model on tpu
         self._device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
         model.to(self._device)

pytorch_lightning/trainer/trainer.py (1 addition, 1 deletion)

@@ -900,7 +900,7 @@ def fit(
 
             # train
             if self.tpu_id is not None:
-                self.tpu_train(model)
+                self.tpu_train(self.tpu_id, model)
             else:
                 xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method)
 
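Why the revert restores the parameter: torch_xla's xmp.spawn invokes its target as fn(index, *args), handing each spawned process its index as the first positional argument, so tpu_train must keep that leading parameter even though its body never reads it; the single-core branch then passes self.tpu_id explicitly so both call paths match the same signature. Below is a minimal, hypothetical sketch of that calling convention. fake_spawn and the toy model string are illustrative stand-ins, not PyTorch Lightning or torch_xla code.

# Hedged sketch of the fn(index, *args) convention described above.
# fake_spawn is a hypothetical stand-in for xmp.spawn, used only to show
# why tpu_train needs the index as its first parameter.

def tpu_train(tpu_core_idx, model):
    # Spawned entry point: the launcher supplies the core/process index
    # first, then the extra args tuple. The index may go unused in the body.
    print(f"core {tpu_core_idx}: training {model}")

def fake_spawn(fn, args=(), nprocs=1):
    # Simplified, single-process stand-in for xmp.spawn: call fn once per
    # "process", passing the index as the first positional argument.
    for index in range(nprocs):
        fn(index, *args)

if __name__ == "__main__":
    # Multi-core path, mirroring xmp.spawn(self.tpu_train, args=(model,), nprocs=...)
    fake_spawn(tpu_train, args=("toy-model",), nprocs=2)
    # Single-core path, mirroring self.tpu_train(self.tpu_id, model)
    tpu_train(0, "toy-model")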
