Skip to content

Commit 7c90913

Browse files
dvolgyes, tchaton, and Borda
authored and committed
Fix for incorrect usage of detach(), cpu(), to() (Lightning-AI#6216)
* Fix for incorrect detach/cpu calls (Lightning-AI#6214) * Fix incorrect use of detach(), to(), and cpu(), Lightning-AI#6214 * Fix incorrect use of detach() and cpu(), Lightning-AI#6214 * update pr * add typing * chlog * more... * revert on module * update on comments * revert changes on model Co-authored-by: tchaton <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 548241b commit 7c90913

File tree

6 files changed

+20
-15
lines changed

6 files changed

+20
-15
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3232
- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
3333

3434

35+
- Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))
36+
37+
3538
- Fixed LBFGS optimizer support which didn't converge in automatic optimization ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
3639

3740

pytorch_lightning/core/step_result.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -416,20 +416,22 @@ def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_i
416416

417417
return result
418418

419-
def detach(self):
419+
def detach(self) -> 'Result':
420420
for k, v in self.items():
421421
if isinstance(v, torch.Tensor):
422422
self.__setitem__(k, v.detach())
423+
return self
423424

424-
def to(self, *args, **kwargs):
425+
def to(self, *args, **kwargs) -> 'Result':
425426
"""Move all self attributes to the given device."""
426427
for k, v in self.items():
427428
if isinstance(v, torch.Tensor):
428429
self.__setitem__(k, v.to(*args, **kwargs))
430+
return self
429431

430-
def cpu(self):
432+
def cpu(self) -> 'Result':
431433
"""Move all self attributes to CPU."""
432-
self.to(torch.device("cpu"))
434+
return self.to(torch.device("cpu"))
433435

434436
def __repr__(self):
435437
self_copy = self.copy()

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,11 @@ def cache_result(self) -> None:
281281
# attach capture batch_size
282282
Result.attach_batch_size(self._batch_size, hook_result)
283283

284-
hook_result.detach()
284+
hook_result = hook_result.detach()
285285
if self.trainer.move_metrics_to_cpu:
286-
hook_result.cpu()
286+
hook_result = hook_result.cpu()
287287
elif self.trainer._distrib_type == DistributedType.DP:
288-
hook_result.to(torch.device("cuda", self.trainer.root_gpu))
288+
hook_result = hook_result.to(torch.device("cuda", self.trainer.root_gpu))
289289

290290
self._internals[fx_name].append(hook_result, info)
291291

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,9 +773,9 @@ def run_evaluation(self, max_batches=None, on_epoch=False):
773773
def track_output_for_epoch_end(self, outputs, output):
774774
if output is not None:
775775
if isinstance(output, Result):
776-
output.detach()
776+
output = output.detach()
777777
if self.move_metrics_to_cpu:
778-
output.cpu()
778+
output = output.cpu()
779779
elif isinstance(output, dict):
780780
output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu)
781781
elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu:

pytorch_lightning/trainer/training_loop.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
263263
is_result_obj = isinstance(training_step_output, Result)
264264

265265
if is_result_obj:
266-
training_step_output.detach()
266+
training_step_output = training_step_output.detach()
267267
else:
268268
training_step_output.batch_loss = training_step_output.batch_loss.detach()
269269

@@ -397,9 +397,9 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch):
397397

398398
# track metrics without grads for epoch reduction
399399
training_step_output_for_epoch_end = copy(result)
400-
training_step_output_for_epoch_end.detach()
400+
training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
401401
if self.trainer.move_metrics_to_cpu:
402-
training_step_output_for_epoch_end.cpu()
402+
training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu()
403403

404404
# what flows back into the system
405405
training_step_output = result

tests/overrides/test_data_parallel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,9 @@ def training_step(self, batch, batch_idx):
144144
output.update({"python scalar": 12.3})
145145
return output
146146

147-
model = TestModel()
148-
model.to(device)
149-
model.running_stage = RunningStage.TRAINING
147+
model = TestModel().to(device)
148+
model.trainer = MagicMock()
149+
model.trainer._running_stage = RunningStage.TRAINING
150150
batch = torch.rand(2, 32).to(device)
151151
batch_idx = 0
152152

0 commit comments

Comments (0)