diff --git a/CHANGELOG.md b/CHANGELOG.md
index 780a8790b9fdd..0713de202b63c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -147,6 +147,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))
 
 
+- Removed legacy references to magic keys in the `Result` object ([#6016](https://github.com/PyTorchLightning/pytorch-lightning/pull/6016))
+
+
 - Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index 6e67f591da7c7..d4a6e3ae94dfc 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -602,7 +602,6 @@ For cases like production, you might want to iterate different models inside a L
         loss = F.cross_entropy(y_hat, y)
         acc = FM.accuracy(y_hat, y)
 
-        # loss is tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
         metrics = {'val_acc': acc, 'val_loss': loss}
         self.log_dict(metrics)
         return metrics
diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst
index d86a8dc1ff472..96e19a7be4694 100644
--- a/docs/source/common/trainer.rst
+++ b/docs/source/common/trainer.rst
@@ -1478,15 +1478,9 @@ with the hidden
     def training_step(self, batch, batch_idx, hiddens):
         # hiddens are the hiddens from the previous truncated backprop step
         out, hiddens = self.lstm(data, hiddens)
-
-        # remember to detach() hiddens.
-        # If you don't, you will get a RuntimeError: Trying to backward through
-        # the graph a second time...
-        # Using hiddens.detach() allows each split to be disconnected.
-
         return {
             "loss": ...,
-            "hiddens": hiddens  # remember to detach() this
+            "hiddens": hiddens
         }
 
 To modify how the batch is split,
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 4448de8e4834b..24ebcdf807357 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -90,7 +90,6 @@ def __init__(
         self.wait_count = 0
         self.stopped_epoch = 0
         self.mode = mode
-        self.warned_result_obj = False
 
         if self.mode not in self.mode_dict:
             raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index d9dea5979ae58..5f0318e7ac8d1 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -202,7 +202,6 @@ def __init__(
         self.best_model_path = ""
         self.last_model_path = ""
         self.save_function = None
-        self.warned_result_obj = False
 
         self.__init_monitor_mode(monitor, mode)
         self.__init_ckpt_dir(dirpath, filename, save_top_k)
diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 3961586f4946a..d61c9fb5d3d1e 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -11,10 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""[Train, Eval]Result for easier logging, checkpointing, early stopping, epoch-wise reduction.""" +"""Result class for easier logging and epoch-wise reduction.""" import numbers -import os from copy import copy from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union @@ -27,33 +26,14 @@ class Result(Dict): - def __init__( - self, - minimize: Optional[Tensor] = None, - early_stop_on: Optional[Tensor] = None, - checkpoint_on: Optional[Union[Tensor, bool]] = None, - hiddens: Optional[Tensor] = None, - ): - + def __init__(self, minimize: Optional[Tensor] = None): super().__init__() - # temporary until dict results are deprecated - os.environ['PL_USING_RESULT_OBJ'] = '1' - - if early_stop_on is not None: - self.early_stop_on = early_stop_on - if checkpoint_on is not None and checkpoint_on: - self.checkpoint_on = checkpoint_on - if hiddens is not None: - self.hiddens = hiddens.detach() if minimize is not None: err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end' self._assert_grad_tensor_metric('minimize', minimize, err) self.minimize = minimize - if minimize is not None and checkpoint_on is None: - self.checkpoint_on = minimize.detach() - self['meta'] = {'_internal': {'_reduce_on_epoch': False, 'batch_sizes': []}} def __getitem__(self, key: Union[str, Any]) -> Any: @@ -64,9 +44,7 @@ def __getitem__(self, key: Union[str, Any]) -> Any: def __getattr__(self, key: str) -> Any: try: - if key == 'callback_metrics': - return self.get_callback_metrics() - elif key == 'batch_log_metrics': + if key == 'batch_log_metrics': return self.get_batch_log_metrics() elif key == 'batch_pbar_metrics': return self.get_batch_pbar_metrics() @@ -80,16 +58,9 @@ def __getattr__(self, key: str) -> Any: return None def __setattr__(self, key: str, val: Union[Tensor, Any]): - # ensure reserve keys are tensors and detached - if key in {'checkpoint_on', 'early_stop_on'}: - self._assert_tensor_metric(key, val) - if val is not None and isinstance(val, torch.Tensor): - val = val.detach() - - # ensure anything else that is a tensor is detached - elif isinstance(val, torch.Tensor) and key != 'minimize': + # ensure tensors are detached + if isinstance(val, torch.Tensor) and key != 'minimize': val = val.detach() - self[key] = val def __getstate__(self): @@ -98,11 +69,6 @@ def __getstate__(self): def __setstate__(self, d): self.update(d) - def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]): - if potential_metric is not None and not isinstance(potential_metric, bool): - if not isinstance(potential_metric, Tensor): - raise TypeError(f'{name} must be a torch.Tensor') - def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''): if x is not None: if not isinstance(x, Tensor): @@ -272,11 +238,6 @@ def get_batch_sizes(self): meta = self['meta'] return torch.tensor(meta['_internal']['batch_sizes']) - def get_callback_metrics(self) -> dict: - result = {'early_stop_on': self.early_stop_on, 'checkpoint_on': self.checkpoint_on} - - return result - def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str: if dataloader_idx is not None and add_dataloader_idx: return f"{k}/dataloader_idx_{dataloader_idx}" @@ -495,25 +456,22 @@ def padded_gather(cls, outputs): # find the padding used for other values default_padding_idx = 0 for name, value in result.items(): - if isinstance(value, list) and len(value) > 0 and isinstance(value[0], 
torch.Tensor): - if name not in {'checkpoint_on', 'early_stop_on', 'minimize'}: - default_padding_idx = meta[name]['tbptt_pad_token'] - break + if ( + name != 'minimize' and isinstance(value, list) and len(value) > 0 + and isinstance(value[0], torch.Tensor) + ): + default_padding_idx = meta[name]['tbptt_pad_token'] + break # pad across each key individually for name, value in result.items(): - is_reserved = name in {'checkpoint_on', 'early_stop_on', 'minimize'} - if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor): - - if is_reserved: - padding_key = default_padding_idx - else: - padding_key = meta[name]['tbptt_pad_token'] + if (isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor)): + padding_key = default_padding_idx if name == 'minimize' else meta[name]['tbptt_pad_token'] padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key) result[name] = padded # also update the result - if meta and not is_reserved: + if meta and name != "minimize": meta[name]['value'] = padded if meta: result['meta'] = meta @@ -581,10 +539,7 @@ def reduce_across_time(cls, time_outputs): continue # pick the reduce fx - if k in ['checkpoint_on', 'early_stop_on', 'minimize']: - tbptt_reduce_fx = torch.mean - else: - tbptt_reduce_fx = meta[k]['tbptt_reduce_fx'] + tbptt_reduce_fx = torch.mean if k == "minimize" else meta[k]['tbptt_reduce_fx'] if isinstance(value, list): value = torch.tensor(value) @@ -612,10 +567,6 @@ def dp_reduce(self): def should_reduce_on_epoch_end(self) -> bool: return self['meta']['_internal']['_reduce_on_epoch'] - def drop_hiddens(self): - if 'hiddens' in self: - del self['hiddens'] - def rename_keys(self, map_dict: dict): """ Maps key values to the target values. Useful when renaming variables in mass. 
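With `early_stop_on`/`checkpoint_on` removed from `Result`, early stopping and checkpointing key off logged metric names instead. A minimal sketch of the replacement pattern (the toy model and the `val_loss` name are illustrative assumptions, not part of this diff):

```python
import torch
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint


class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.layer(x), y)
        # any logged name works; the callbacks find it via `monitor`
        self.log('val_loss', loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# instead of returning `early_stop_on`/`checkpoint_on` from a step:
trainer = pl.Trainer(
    max_epochs=3,
    callbacks=[EarlyStopping(monitor='val_loss'), ModelCheckpoint(monitor='val_loss')],
)
```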
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 0bae4effbb383..6a9d75638bbe0 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -317,10 +317,6 @@ def _track_callback_metrics(self, eval_results):
                 elif isinstance(eval_result, dict):
                     flat = flatten_dict(eval_result)
-                    # removing val_loss magic word to map to checkpoint + ES callback
-                    if 'val_loss' in flat:
-                        flat['checkpoint_on'] = flat['val_loss']
-                        flat['early_stop_on'] = flat['val_loss']
                     self.trainer.logger_connector.callback_metrics.update(flat)
                     if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
                         self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
@@ -331,11 +327,6 @@ def _track_callback_metrics(self, eval_results):
             else:
                 flat = flatten_dict(eval_results)
-                # removing val_loss magic word to map to checkpoint + ES callback
-                if 'val_loss' in flat:
-                    flat['checkpoint_on'] = flat['val_loss']
-                    flat['early_stop_on'] = flat['val_loss']
-
                 self.trainer.logger_connector.callback_metrics.update(flat)
                 if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
                     self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
@@ -370,26 +361,13 @@ def on_train_epoch_end(self):
         # inform cached logger connector epoch finished
         self.cached_results.has_batch_loop_finished = True
 
-    def log_train_epoch_end_metrics(
-        self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers
-    ):
+    def log_train_epoch_end_metrics(self, epoch_output, num_optimizers):
         # epoch output is a list. Each item in that list has all the outputs per optimizer
         # epoch_output[optimizer_idx][training_step_idx][tbptt_index]
         # remember that not using truncated backprop is equivalent with truncated back prop of len(1)
         model = self.trainer.lightning_module
 
-        epoch_callback_metrics = {}
-
-        # -----------------------
-        # Calculate epoch callback values if given
-        # -----------------------
-        if checkpoint_accumulator.num_values > 0:
-            epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()
-
-        if early_stopping_accumulator.num_values > 0:
-            epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()
-
         # ------------------------
         # determine if using a result obj
         # ------------------------
@@ -437,9 +415,6 @@ def log_train_epoch_end_metrics(
             self.log_metrics(epoch_log_metrics, {})
             self._callback_metrics.update(epoch_log_metrics)
 
-        # add metrics to callbacks
-        self._callback_metrics.update(epoch_callback_metrics)
-
         # add metrics to progress_bar and callbacks
         if len(epoch_progress_bar_metrics) > 0:
             self.add_progress_bar_metrics(epoch_progress_bar_metrics)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index a87073428e725..a1b66cb561889 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -251,11 +251,6 @@ def __gather_epoch_end_eval_results(self, outputs):
         eval_results = []
         for epoch_output in outputs:
             result = epoch_output[0].__class__.gather(epoch_output)
-            if 'checkpoint_on' in result:
-                result.checkpoint_on = result.checkpoint_on.mean()
-            if 'early_stop_on' in result:
-                result.early_stop_on = result.early_stop_on.mean()
-
             eval_results.append(result)
 
         # with 1 dataloader don't pass in a list
@@ -269,10 +264,6 @@ def __auto_reduce_result_objs(self, outputs):
         for dl_output in outputs:
             result = dl_output[0]
             result = result.__class__.reduce_on_epoch_end(dl_output)
-            if 'checkpoint_on' in result:
-                result.checkpoint_on = result.checkpoint_on.mean()
-            if 'early_stop_on' in result:
-                result.early_stop_on = result.early_stop_on.mean()
             eval_results.append(result)
 
         return eval_results
diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index 1eb77d517d952..8aaac0a659152 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -14,7 +14,7 @@
 import inspect
 from abc import ABC
-from typing import Mapping
+from collections.abc import Mapping
 
 import torch
@@ -76,10 +76,7 @@ def process_dict_result(self, output, train=False):
         # --------------------------
         # single scalar returned from a xx_step
         if isinstance(output, torch.Tensor):
-            progress_bar_metrics = {}
-            log_metrics = {}
-            hiddens = None
-            return output, progress_bar_metrics, log_metrics, hiddens
+            return output, {}, {}, None
 
         # ---------------
         # EXTRACT PROGRESS BAR KEYS
@@ -140,6 +137,8 @@ def process_dict_result(self, output, train=False):
         # EXTRACT HIDDEN
         # ---------------
         hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None
+        if hiddens is not None:
+            hiddens = hiddens.detach()
 
         # detach all metrics for callbacks to prevent memory leaks
         # no .item() because it will slow things down
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index aff458d1b6084..f884306dc09c8 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -104,21 +104,6 @@ def _agg_memory(self, how: str):
             return getattr(self.memory[:self.current_idx], how)()
 
 
-class Accumulator(object):
-
-    def __init__(self):
-        self.num_values = 0
-        self.total = 0
-
-    def accumulate(self, x):
-        with torch.no_grad():
-            self.total += x
-            self.num_values += 1
-
-    def mean(self):
-        return self.total / self.num_values
-
-
 class PredictionCollection(object):
 
     def __init__(self, global_rank: int, world_size: int):
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index a77524b8a55e7..696f14742935c 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -24,7 +24,7 @@
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum
+from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
 from pytorch_lightning.utilities.distributed import rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -38,8 +38,6 @@ class TrainLoop:
 
     def __init__(self, trainer, multiple_trainloader_mode: str):
         self.trainer = trainer
-        self.early_stopping_accumulator = None
-        self.checkpoint_accumulator = None
         self.accumulated_loss = None
         self.warning_cache = WarningCache()
         self._teardown_already_run = False
@@ -182,10 +180,6 @@ def on_train_epoch_start(self, epoch):
         # stores accumulated grad fractions per batch
         self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches)
 
-        # structured result accumulators for callbacks
-        self.early_stopping_accumulator = Accumulator()
-        self.checkpoint_accumulator = Accumulator()
-
         # hook
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")
@@ -318,7 +312,6 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
             loss=untouched_loss,
             training_step_output=training_step_output,
             training_step_output_for_epoch_end=training_step_output_for_epoch_end,
-            hiddens=training_step_output.hiddens,
         )
         return result
 
@@ -348,7 +341,6 @@ def _process_training_step_output(self, training_step_output, split_batch):
             batch_loss=training_step_output[0],
             pbar_on_batch_end=training_step_output[1],
             log_metrics=training_step_output[2],
-            hiddens=training_step_output[3],
         )
         # if the user decides to finally reduce things in epoch_end, save raw output without graphs
         if isinstance(training_step_output_for_epoch_end, torch.Tensor):
@@ -363,6 +355,7 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch):
 
         loss = None
         hiddens = None
+        result["extra"] = {}
 
         # handle dict return
         if isinstance(training_step_output, dict):
@@ -373,11 +366,10 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch):
 
         # handle scalar return
         elif isinstance(training_step_output, torch.Tensor):
             loss = training_step_output
-            result["extra"] = {}
 
         # map to results under the hood
         result.minimize = loss
-        result.hiddens = hiddens
+        self.trainer.hiddens = hiddens
 
         # track batch for manual reduction with result
         result.track_batch_size(len(split_batch))
@@ -443,12 +435,6 @@ def _track_gradient_norm(self):
             grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm)
         return grad_norm_dict
 
-    def process_hiddens(self, opt_closure_result):
-        hiddens = opt_closure_result.hiddens
-        if isinstance(opt_closure_result.training_step_output, Result):
-            opt_closure_result.training_step_output_for_epoch_end.drop_hiddens()
-        return hiddens
-
     def tbptt_split_batch(self, batch):
         splits = [batch]
         if self.trainer.truncated_bptt_steps is not None:
@@ -482,11 +468,7 @@ def run_training_epoch(self):
                 if batch_output.signal == -1:
                     break
 
-                batch_end_outputs = self.process_train_step_outputs(
-                    batch_output.training_step_output_for_epoch_end,
-                    self.early_stopping_accumulator,
-                    self.checkpoint_accumulator,
-                )
+                batch_end_outputs = self.process_train_step_outputs(batch_output.training_step_output_for_epoch_end)
                 # hook
                 # TODO: add outputs to batches
                 self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)
@@ -542,9 +524,7 @@ def run_training_epoch(self):
             self.on_train_epoch_end(epoch_output)
 
         # log epoch metrics
-        self.trainer.logger_connector.log_train_epoch_end_metrics(
-            epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
-        )
+        self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output, self.num_optimizers)
 
         should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
         should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
@@ -698,9 +678,6 @@ def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list:
             # cache metrics
             self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result)
 
-            # track hiddens
-            self.trainer.hiddens = self.process_hiddens(opt_closure_result)
-
             # check if loss or model weights are nan
             if self.trainer.terminate_on_nan:
                 self.trainer.detect_nan_tensors(opt_closure_result.loss)
@@ -853,32 +830,15 @@ def save_loggers_on_train_batch_end(self):
         if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
             self.trainer.logger.save()
 
-    def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator):
+    def process_train_step_outputs(self, all_train_step_outputs):
         """
         Figure out what needs to be tracked/logged at the end of the epoch
         """
-
         # the training step outputs a list per optimizer. The list contains the outputs at each time step
         # when no TBPTT is used, then the list has 1 item per batch
         # when TBPTT IS used, then the list has n items (1 per time step)
-        batch_end_outputs = []
-        for optimizer_idx_outputs in all_train_step_outputs:
-            # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
-            if len(optimizer_idx_outputs) == 0:
-                continue
-
-            sample_output = optimizer_idx_outputs[-1]
-
-            # pull out callback info if available (ie: Results object)
-            if isinstance(sample_output, dict) and "early_stop_on" in sample_output:
-                early_stopping_accumulator.accumulate(sample_output["early_stop_on"])
-
-            if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
-                checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])
-
-            batch_end_outputs.append(optimizer_idx_outputs)
-
-        return batch_end_outputs
+        # keep only the per-optimizer output lists that are non-empty
+        return [opt_idx_out for opt_idx_out in all_train_step_outputs if len(opt_idx_out)]
 
     def prepare_optimizers(self):
         # in manual optimization we loop over all optimizers at once
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index f22fb30e94c76..25103559cd070 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -111,7 +111,7 @@ def training_step_end(self, *_):
     assert generated == excepted
 
 
-def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
+def test__logger_connector__epoch_result_store__train__tbptt(tmpdir):
     """
-    Tests that LoggerConnector will properly capture logged information with ttbt and reduce them
+    Tests that LoggerConnector will properly capture logged information with tbptt and reduce them
     """
@@ -142,6 +142,7 @@ def __init__(self):
 
     @decorator_with_arguments(fx_name="training_step")
     def training_step(self, batch, batch_idx, hiddens):
+        assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
        self.test_hidden = torch.rand(1)
 
         x_tensor, y_list = batch
diff --git a/tests/trainer/logging_/test_train_loop_logging_1_0.py b/tests/trainer/logging_/test_train_loop_logging_1_0.py
index f8672eb4ec51e..092751ec68e33 100644
--- a/tests/trainer/logging_/test_train_loop_logging_1_0.py
+++ b/tests/trainer/logging_/test_train_loop_logging_1_0.py
@@ -318,12 +318,7 @@ def __init__(self):
             self.layer = torch.nn.Linear(2, 2)
 
         def training_step(self, batch, batch_idx, hiddens):
-            try:
-                assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-                # todo: specify the possible exception
-            except Exception as ex:
-                print(ex)
-
+            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
             self.test_hidden = torch.rand(1)
 
             x_tensor, y_list = batch
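With the `Result`-side `hiddens` handling removed, truncated backprop follows the contract exercised by the updated tests above: `training_step` receives the hiddens from the previous split and returns the new ones under the `"hiddens"` key, and the trainer detaches them (see the `hiddens.detach()` added in `process_dict_result`), so user code no longer calls `.detach()` itself. A hedged sketch of the user-facing pattern (the module, shapes, and split size are illustrative, not part of this diff):

```python
import torch
import pytorch_lightning as pl


class TBPTTModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=8, batch_first=True)

    def training_step(self, batch, batch_idx, hiddens):
        # `hiddens` carries the (already detached) state from the previous split
        out, hiddens = self.lstm(batch, hiddens)
        loss = out.sum()  # placeholder loss for illustration
        # no manual detach: the trainer detaches `hiddens` before the next split
        return {"loss": loss, "hiddens": hiddens}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# split each batch along the time dimension into chunks of 2 steps
trainer = pl.Trainer(truncated_bptt_steps=2, max_epochs=1)
```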