Commit f0c5479

Remove legacy Result parameters (#6016)
1 parent 0e45220 commit f0c5479

13 files changed: +35 additions, -184 deletions

CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))


+- Removed legacy references for magic keys in the `Result` object ([#6016](https://github.com/PyTorchLightning/pytorch-lightning/pull/6016))
+
+
 - Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))


docs/source/common/lightning_module.rst
Lines changed: 0 additions & 1 deletion

@@ -602,7 +602,6 @@ For cases like production, you might want to iterate different models inside a LightningModule.
         loss = F.cross_entropy(y_hat, y)
         acc = FM.accuracy(y_hat, y)

-        # loss is tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
         metrics = {'val_acc': acc, 'val_loss': loss}
         self.log_dict(metrics)
         return metrics

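The comment deleted above referred to the old 'checkpoint_on' magic key; with that mechanism removed, checkpointing and early stopping are driven by an explicitly monitored metric instead. A minimal sketch of the replacement pattern, assuming a hypothetical LightningModule whose validation_step logs 'val_loss' as in the docs snippet above:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# validation_step logs 'val_loss' via self.log_dict(metrics); the callbacks are
# then pointed at that key explicitly rather than at the removed 'checkpoint_on'.
trainer = pl.Trainer(
    callbacks=[
        ModelCheckpoint(monitor="val_loss", mode="min"),
        EarlyStopping(monitor="val_loss", mode="min"),
    ]
)
# trainer.fit(LitClassifier(), datamodule=dm)  # LitClassifier and dm are placeholders
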
docs/source/common/trainer.rst
Lines changed: 1 addition & 7 deletions

@@ -1478,15 +1478,9 @@ with the hidden
     def training_step(self, batch, batch_idx, hiddens):
         # hiddens are the hiddens from the previous truncated backprop step
         out, hiddens = self.lstm(data, hiddens)
-
-        # remember to detach() hiddens.
-        # If you don't, you will get a RuntimeError: Trying to backward through
-        # the graph a second time...
-        # Using hiddens.detach() allows each split to be disconnected.
-
         return {
             "loss": ...,
-            "hiddens": hiddens  # remember to detach() this
+            "hiddens": hiddens
         }

 To modify how the batch is split,

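For orientation, a runnable sketch of the truncated-backprop pattern this docs change describes. The module, layer sizes, and loss are assumptions for illustration (a GRU is used so the hidden state is a single tensor); the key point is that training_step returns the hiddens without a manual detach, since the trainer now detaches them (see the pytorch_lightning/trainer/logging.py change below):

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class TBPTTExample(pl.LightningModule):  # hypothetical module, not from the repository
    def __init__(self, input_size=10, hidden_size=20):
        super().__init__()
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, input_size)

    def training_step(self, batch, batch_idx, hiddens):
        # hiddens carry over from the previous truncated backprop split (None at first)
        x, y = batch
        out, hiddens = self.rnn(x, hiddens)
        loss = F.mse_loss(self.head(out), y)
        # no manual hiddens.detach() needed anymore; the trainer detaches what is returned
        return {"loss": loss, "hiddens": hiddens}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# splitting is still enabled through the Trainer flag documented in this section:
# trainer = pl.Trainer(truncated_bptt_steps=50)
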
pytorch_lightning/callbacks/early_stopping.py
Lines changed: 0 additions & 1 deletion

@@ -90,7 +90,6 @@ def __init__(
         self.wait_count = 0
         self.stopped_epoch = 0
         self.mode = mode
-        self.warned_result_obj = False

         if self.mode not in self.mode_dict:
             raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")

pytorch_lightning/callbacks/model_checkpoint.py
Lines changed: 0 additions & 1 deletion

@@ -202,7 +202,6 @@ def __init__(
         self.best_model_path = ""
         self.last_model_path = ""
         self.save_function = None
-        self.warned_result_obj = False

         self.__init_monitor_mode(monitor, mode)
         self.__init_ckpt_dir(dirpath, filename, save_top_k)

pytorch_lightning/core/step_result.py
Lines changed: 15 additions & 64 deletions

@@ -11,10 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""[Train, Eval]Result for easier logging, checkpointing, early stopping, epoch-wise reduction."""
+"""Result class for easier logging and epoch-wise reduction."""

 import numbers
-import os
 from copy import copy
 from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union

@@ -27,33 +26,14 @@

 class Result(Dict):

-    def __init__(
-        self,
-        minimize: Optional[Tensor] = None,
-        early_stop_on: Optional[Tensor] = None,
-        checkpoint_on: Optional[Union[Tensor, bool]] = None,
-        hiddens: Optional[Tensor] = None,
-    ):
-
+    def __init__(self, minimize: Optional[Tensor] = None):
         super().__init__()

-        # temporary until dict results are deprecated
-        os.environ['PL_USING_RESULT_OBJ'] = '1'
-
-        if early_stop_on is not None:
-            self.early_stop_on = early_stop_on
-        if checkpoint_on is not None and checkpoint_on:
-            self.checkpoint_on = checkpoint_on
-        if hiddens is not None:
-            self.hiddens = hiddens.detach()
         if minimize is not None:
             err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end'
             self._assert_grad_tensor_metric('minimize', minimize, err)
             self.minimize = minimize

-        if minimize is not None and checkpoint_on is None:
-            self.checkpoint_on = minimize.detach()
-
         self['meta'] = {'_internal': {'_reduce_on_epoch': False, 'batch_sizes': []}}

     def __getitem__(self, key: Union[str, Any]) -> Any:

@@ -64,9 +44,7 @@ def __getitem__(self, key: Union[str, Any]) -> Any:

     def __getattr__(self, key: str) -> Any:
         try:
-            if key == 'callback_metrics':
-                return self.get_callback_metrics()
-            elif key == 'batch_log_metrics':
+            if key == 'batch_log_metrics':
                 return self.get_batch_log_metrics()
             elif key == 'batch_pbar_metrics':
                 return self.get_batch_pbar_metrics()

@@ -80,16 +58,9 @@ def __getattr__(self, key: str) -> Any:
         return None

     def __setattr__(self, key: str, val: Union[Tensor, Any]):
-        # ensure reserve keys are tensors and detached
-        if key in {'checkpoint_on', 'early_stop_on'}:
-            self._assert_tensor_metric(key, val)
-            if val is not None and isinstance(val, torch.Tensor):
-                val = val.detach()
-
-        # ensure anything else that is a tensor is detached
-        elif isinstance(val, torch.Tensor) and key != 'minimize':
+        # ensure tensors are detached
+        if isinstance(val, torch.Tensor) and key != 'minimize':
             val = val.detach()
-
         self[key] = val

     def __getstate__(self):

@@ -98,11 +69,6 @@ def __getstate__(self):
     def __setstate__(self, d):
         self.update(d)

-    def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
-        if potential_metric is not None and not isinstance(potential_metric, bool):
-            if not isinstance(potential_metric, Tensor):
-                raise TypeError(f'{name} must be a torch.Tensor')
-
     def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''):
         if x is not None:
             if not isinstance(x, Tensor):

@@ -272,11 +238,6 @@ def get_batch_sizes(self):
         meta = self['meta']
         return torch.tensor(meta['_internal']['batch_sizes'])

-    def get_callback_metrics(self) -> dict:
-        result = {'early_stop_on': self.early_stop_on, 'checkpoint_on': self.checkpoint_on}
-
-        return result
-
     def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str:
         if dataloader_idx is not None and add_dataloader_idx:
             return f"{k}/dataloader_idx_{dataloader_idx}"

@@ -495,25 +456,22 @@ def padded_gather(cls, outputs):
         # find the padding used for other values
         default_padding_idx = 0
         for name, value in result.items():
-            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
-                if name not in {'checkpoint_on', 'early_stop_on', 'minimize'}:
-                    default_padding_idx = meta[name]['tbptt_pad_token']
-                    break
+            if (
+                name != 'minimize' and isinstance(value, list) and len(value) > 0
+                and isinstance(value[0], torch.Tensor)
+            ):
+                default_padding_idx = meta[name]['tbptt_pad_token']
+                break

         # pad across each key individually
         for name, value in result.items():
-            is_reserved = name in {'checkpoint_on', 'early_stop_on', 'minimize'}
-            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
-
-                if is_reserved:
-                    padding_key = default_padding_idx
-                else:
-                    padding_key = meta[name]['tbptt_pad_token']
+            if (isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor)):
+                padding_key = default_padding_idx if name == 'minimize' else meta[name]['tbptt_pad_token']
                 padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key)
                 result[name] = padded

                 # also update the result
-                if meta and not is_reserved:
+                if meta and name != "minimize":
                     meta[name]['value'] = padded
         if meta:
             result['meta'] = meta

@@ -581,10 +539,7 @@ def reduce_across_time(cls, time_outputs):
                 continue

             # pick the reduce fx
-            if k in ['checkpoint_on', 'early_stop_on', 'minimize']:
-                tbptt_reduce_fx = torch.mean
-            else:
-                tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']
+            tbptt_reduce_fx = torch.mean if k == "minimize" else meta[k]['tbptt_reduce_fx']

             if isinstance(value, list):
                 value = torch.tensor(value)

@@ -612,10 +567,6 @@ def dp_reduce(self):
     def should_reduce_on_epoch_end(self) -> bool:
         return self['meta']['_internal']['_reduce_on_epoch']

-    def drop_hiddens(self):
-        if 'hiddens' in self:
-            del self['hiddens']
-
     def rename_keys(self, map_dict: dict):
         """
         Maps key values to the target values. Useful when renaming variables in mass.

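A small sketch, not part of the commit, of how the slimmed-down Result behaves after this change: minimize is the only constructor argument left and the only key exempt from detaching in __setattr__, while the former early_stop_on, checkpoint_on, and hiddens parameters are gone entirely:

import torch
from pytorch_lightning.core.step_result import Result

# 'minimize' must carry a computational graph; _assert_grad_tensor_metric enforces this
loss = (torch.randn(4, requires_grad=True) ** 2).mean()

result = Result(minimize=loss)                    # the only remaining constructor argument
result.val_acc = torch.tensor(0.9, requires_grad=True)

assert result['minimize'].requires_grad           # kept attached for the backward pass
assert not result['val_acc'].requires_grad        # any other tensor is detached on assignment
# Result(checkpoint_on=..., early_stop_on=..., hiddens=...) now raises a TypeError
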
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
Lines changed: 1 addition & 26 deletions

@@ -317,10 +317,6 @@ def _track_callback_metrics(self, eval_results):
             elif isinstance(eval_result, dict):
                 flat = flatten_dict(eval_result)

-                # removing val_loss magic word to map to checkpoint + ES callback
-                if 'val_loss' in flat:
-                    flat['checkpoint_on'] = flat['val_loss']
-                    flat['early_stop_on'] = flat['val_loss']
                 self.trainer.logger_connector.callback_metrics.update(flat)
                 if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
                     self.trainer.logger_connector.evaluation_callback_metrics.update(flat)

@@ -331,11 +327,6 @@ def _track_callback_metrics(self, eval_results):
         else:
             flat = flatten_dict(eval_results)

-            # removing val_loss magic word to map to checkpoint + ES callback
-            if 'val_loss' in flat:
-                flat['checkpoint_on'] = flat['val_loss']
-                flat['early_stop_on'] = flat['val_loss']
-
             self.trainer.logger_connector.callback_metrics.update(flat)
             if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
                 self.trainer.logger_connector.evaluation_callback_metrics.update(flat)

@@ -370,26 +361,13 @@ def on_train_epoch_end(self):
         # inform cached logger connector epoch finished
         self.cached_results.has_batch_loop_finished = True

-    def log_train_epoch_end_metrics(
-        self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers
-    ):
+    def log_train_epoch_end_metrics(self, epoch_output, num_optimizers):
         # epoch output is a list. Each item in that list has all the outputs per optimizer
         # epoch_output[optimizer_idx][training_step_idx][tbptt_index]
         # remember that not using truncated backprop is equivalent with truncated back prop of len(1)

         model = self.trainer.lightning_module

-        epoch_callback_metrics = {}
-
-        # -----------------------
-        # Calculate epoch callback values if given
-        # -----------------------
-        if checkpoint_accumulator.num_values > 0:
-            epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()
-
-        if early_stopping_accumulator.num_values > 0:
-            epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()
-
         # ------------------------
         # determine if using a result obj
         # ------------------------

@@ -437,9 +415,6 @@ def log_train_epoch_end_metrics(
         self.log_metrics(epoch_log_metrics, {})
         self._callback_metrics.update(epoch_log_metrics)

-        # add metrics to callbacks
-        self._callback_metrics.update(epoch_callback_metrics)
-
         # add metrics to progress_bar and callbacks
         if len(epoch_progress_bar_metrics) > 0:
             self.add_progress_bar_metrics(epoch_progress_bar_metrics)

pytorch_lightning/trainer/evaluation_loop.py
Lines changed: 0 additions & 9 deletions

@@ -251,11 +251,6 @@ def __gather_epoch_end_eval_results(self, outputs):
         eval_results = []
         for epoch_output in outputs:
             result = epoch_output[0].__class__.gather(epoch_output)
-            if 'checkpoint_on' in result:
-                result.checkpoint_on = result.checkpoint_on.mean()
-            if 'early_stop_on' in result:
-                result.early_stop_on = result.early_stop_on.mean()
-
             eval_results.append(result)

         # with 1 dataloader don't pass in a list

@@ -269,10 +264,6 @@ def __auto_reduce_result_objs(self, outputs):
         for dl_output in outputs:
             result = dl_output[0]
             result = result.__class__.reduce_on_epoch_end(dl_output)
-            if 'checkpoint_on' in result:
-                result.checkpoint_on = result.checkpoint_on.mean()
-            if 'early_stop_on' in result:
-                result.early_stop_on = result.early_stop_on.mean()
             eval_results.append(result)

         return eval_results

pytorch_lightning/trainer/logging.py
Lines changed: 4 additions & 5 deletions

@@ -14,7 +14,7 @@

 import inspect
 from abc import ABC
-from typing import Mapping
+from collections import Mapping

 import torch

@@ -76,10 +76,7 @@ def process_dict_result(self, output, train=False):
         # --------------------------
         # single scalar returned from a xx_step
         if isinstance(output, torch.Tensor):
-            progress_bar_metrics = {}
-            log_metrics = {}
-            hiddens = None
-            return output, progress_bar_metrics, log_metrics, hiddens
+            return output, {}, {}, None

         # ---------------
         # EXTRACT PROGRESS BAR KEYS

@@ -140,6 +137,8 @@ def process_dict_result(self, output, train=False):
         # EXTRACT HIDDEN
         # ---------------
         hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None
+        if hiddens is not None:
+            hiddens = hiddens.detach()

         # detach all metrics for callbacks to prevent memory leaks
         # no .item() because it will slow things down

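The two added lines move the hidden-state detach that the trainer.rst comments used to ask users to perform by hand into process_dict_result itself. A standalone PyTorch illustration of why that detach matters, with toy sizes and a GRU so the hidden state is a single tensor:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=4, hidden_size=8, batch_first=True)
x1, x2 = torch.randn(2, 1, 5, 4)   # two consecutive truncated-backprop splits

out, hidden = gru(x1)
out.sum().backward()               # first split: fine, and its graph is freed afterwards

# Feeding the still-attached hidden state into the next split and calling backward
# again would raise "RuntimeError: Trying to backward through the graph a second
# time", the error the removed trainer.rst comments warned about.
hidden = hidden.detach()           # the detach that process_dict_result now applies
out, hidden = gru(x2, hidden)
out.sum().backward()               # second split succeeds: the graphs are disconnected
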
pytorch_lightning/trainer/supporters.py
Lines changed: 0 additions & 15 deletions

@@ -104,21 +104,6 @@ def _agg_memory(self, how: str):
         return getattr(self.memory[:self.current_idx], how)()


-class Accumulator(object):
-
-    def __init__(self):
-        self.num_values = 0
-        self.total = 0
-
-    def accumulate(self, x):
-        with torch.no_grad():
-            self.total += x
-            self.num_values += 1
-
-    def mean(self):
-        return self.total / self.num_values
-
-
 class PredictionCollection(object):

     def __init__(self, global_rank: int, world_size: int):
