diff --git a/CHANGELOG.md b/CHANGELOG.md index 020d354f6d0c8..40a7cf54676b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,6 +85,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault-tolerant training * Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) + * Checkpoint the loop results ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) + + +- Add `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) - Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734)) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 1c465ae314e4f..699be201f95b8 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -106,6 +106,23 @@ Note if you use any built in metrics or custom metrics that use the :doc:`Metric # Add sync_dist=True to sync logging across all GPU workers self.log('test_loss', loss, on_step=True, on_epoch=True, sync_dist=True) +It is possible to perform some computation manually and log the reduced result on rank 0 as follows: + +.. testcode:: + + def test_step(self, batch, batch_idx): + x, y = batch + tensors = self(x) + return tensors + + def test_epoch_end(self, outputs): + mean = torch.mean(self.all_gather(outputs)) + + # When logging only on rank 0, don't forget to add + # ``rank_zero_only=True`` to avoid deadlocks on synchronization. + if self.trainer.is_global_zero: + self.log("my_reduced_metric", mean, rank_zero_only=True) + Make models pickleable ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e7c9852968b36..bf05b1f0772f0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -39,11 +39,12 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES +from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin -from pytorch_lightning.utilities.distributed import sync_ddp_if_available +from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -112,6 +113,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() + self._metric_attributes: Optional[Dict[int, str]] = None def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -273,6 +275,8 @@ def log( sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, 
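For contrast with the ``test_epoch_end`` snippet in the docs above, here is a minimal sketch of the two distributed-logging strategies side by side: automatic reduction via ``sync_dist=True`` versus a manual ``all_gather`` followed by a rank-zero-only log. The module, metric names, and loss are illustrative placeholders, not part of the patch.

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class DistributedLoggingModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)

        def forward(self, x):
            return self.layer(x)

        def test_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self(x), y)
            # Option 1: let Lightning reduce the value across all processes.
            self.log("test_loss_synced", loss, sync_dist=True)
            return loss

        def test_epoch_end(self, outputs):
            # Option 2: gather and reduce manually, then log on rank 0 only.
            mean = torch.mean(self.all_gather(torch.stack(outputs)))
            if self.trainer.is_global_zero:
                # Without rank_zero_only=True this call could deadlock, because
                # the other ranks never reach it to participate in a sync.
                self.log("test_loss_manual", mean, rank_zero_only=True)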
batch_size: Optional[int] = None, + metric_attribute: Optional[str] = None, + rank_zero_only: Optional[bool] = None, ) -> None: """ Log a key, value @@ -310,6 +314,10 @@ def log( each dataloader to not mix values batch_size: Current batch_size. This will be directly inferred from the loaded batch, but some data structures might need to explicitly provide it. + metric_attribute: To restore the metric state, Lightning requires the reference of the + :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute. + rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which + would produce a deadlock as not all processes would perform this log call. """ if tbptt_reduce_fx is not None: rank_zero_deprecation( @@ -346,7 +354,7 @@ def log( results = self.trainer._results assert results is not None assert self._current_fx_name is not None - results.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) + FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) # make sure user doesn't introduce logic for multi-dataloaders if "/dataloader_idx_" in name: @@ -362,6 +370,27 @@ def log( # reset any tensors for the new hook name results.reset(metrics=False, fx=self._current_fx_name) + if metric_attribute is None and isinstance(value, Metric): + if self._metric_attributes is None: + # compute once + self._metric_attributes = { + id(module): name + for name, module in self.named_children() if isinstance(module, Metric) + } + if not self._metric_attributes: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." + " You can fix this by setting an attribute for the metric in your `LightningModule`." + ) + # try to find the passed metric in the LightningModule + metric_attribute = self._metric_attributes.get(id(value)) + if metric_attribute is None: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." 
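The ``metric_attribute`` argument documented above exists so that Lightning can checkpoint and later restore the state of a logged ``torchmetrics.Metric``; it is looked up automatically among the module's direct children and only needs to be passed when that lookup fails. A small sketch of supplying it explicitly (the model below is illustrative, not taken from the patch).

.. code-block:: python

    import torch
    import torchmetrics
    import pytorch_lightning as pl


    class MetricAttributeModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            self.train_acc = torchmetrics.Accuracy()

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self.layer(x)
            loss = torch.nn.functional.cross_entropy(logits, y)

            self.train_acc(logits.softmax(dim=-1), y)
            # `metric_attribute` is inferred from `named_children()` when possible;
            # passing it explicitly is only required when that inference fails,
            # e.g. when the metric is not a direct attribute of the module.
            self.log("train_acc", self.train_acc, on_epoch=True, metric_attribute="train_acc")
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)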
+ f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one" + f" of {list(self._metric_attributes.values())}" + ) + results.log( self._current_fx_name, name, @@ -374,9 +403,11 @@ def log( enable_graph=enable_graph, dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), batch_size=batch_size, - sync_dist=sync_dist, - sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, + sync_dist=sync_dist and distributed_available(), + sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp, sync_dist_group=sync_dist_group, + metric_attribute=metric_attribute, + rank_zero_only=rank_zero_only, ) self.trainer.logger_connector._current_fx = self._current_fx_name diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 7736ed24baefe..803d08eb3e645 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -25,8 +25,8 @@ import numpy as np import torch +import pytorch_lightning as pl from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only @@ -300,7 +300,7 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): kwargs: Optional keywoard arguments, depends on the specific logger being used """ - def log_graph(self, model: LightningModule, input_array=None) -> None: + def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None: """ Record model graph @@ -396,7 +396,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: for logger in self._logger_iterable: logger.log_hyperparams(params) - def log_graph(self, model: LightningModule, input_array=None) -> None: + def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None: for logger in self._logger_iterable: logger.log_graph(model, input_array) diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 148e512f5e439..498a16a9daa29 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -24,7 +24,7 @@ import torch from torch import is_tensor -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -318,6 +318,6 @@ def __getstate__(self): state["_experiment"] = None return state - def log_graph(self, model: LightningModule, input_array=None) -> None: + def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None: if self._experiment is not None: self._experiment.set_model_graph(model) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index b69f31ae53b32..d59830bd98ae4 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -25,7 +25,7 @@ from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard.summary import hparams -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, 
rank_zero_warn @@ -223,7 +223,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> raise ValueError(m) from ex @rank_zero_only - def log_graph(self, model: LightningModule, input_array=None): + def log_graph(self, model: 'pl.LightningModule', input_array=None): if self._log_graph: if input_array is None: input_array = model.example_input_array diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 1107a0bcb2c4c..1650ab8f4ba49 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -18,7 +18,7 @@ from argparse import Namespace from typing import Any, Dict, Optional, Union -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_warn from pytorch_lightning.utilities.distributed import rank_zero_only @@ -153,7 +153,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> self.experiment.log(metrics, global_step=step) @rank_zero_only - def log_graph(self, model: LightningModule, input_array=None): + def log_graph(self, model: 'pl.LightningModule', input_array=None): if self._log_graph: if input_array is None: input_array = model.example_input_array diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index f7c3b8d5fd575..e531db6de77f3 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE @@ -23,7 +23,7 @@ class LightningShardedDataParallel(_LightningModuleWrapperBase): # Just do this for later docstrings pass - def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule: + def unwrap_lightning_module_sharded(wrapped_model) -> 'pl.LightningModule': model = wrapped_model if isinstance(model, ShardedDataParallel): model = model.module diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 21253ea9ab4a0..b2565e7dd34b4 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -18,7 +18,6 @@ from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType from pytorch_lightning.utilities.types import _PARAMETERS @@ -50,7 +49,7 @@ def dispatch(self, trainer: 'pl.Trainer') -> None: def backward( self, - model: LightningModule, + model: 'pl.LightningModule', closure_loss: Tensor, optimizer: Optimizer, opt_idx: int, @@ -76,7 +75,7 @@ def backward( # do backward pass # TODO: not entirely sure, why we need this - if model is not None and isinstance(model, LightningModule): + if model is not None and isinstance(model, pl.LightningModule): model.backward(closure_loss, optimizer, opt_idx, **kwargs) # TODO: avoid dev_debugger and track these calls with mock @@ -118,7 +117,7 @@ def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Seq def pre_optimizer_step( self, - pl_module: LightningModule, + pl_module: 'pl.LightningModule', optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py index e0ecddf322250..86177c5500e2f 100644 --- a/pytorch_lightning/plugins/precision/double.py +++ b/pytorch_lightning/plugins/precision/double.py @@ -18,7 +18,7 @@ import torch.nn as nn from torch.optim import Optimizer -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -33,7 +33,7 @@ class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase): pl_module: the model to wrap """ - def __init__(self, pl_module: LightningModule): + def __init__(self, pl_module: 'pl.LightningModule'): super().__init__(pl_module) @staticmethod @@ -96,7 +96,7 @@ def connect( incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or `lr_schedulers`. 
""" - model = cast(LightningModule, model.to(dtype=torch.float64)) + model = cast(pl.LightningModule, model.to(dtype=torch.float64)) model = LightningDoublePrecisionModule(model) return super().connect(model, optimizers, lr_schedulers) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 47f2a64c04759..b71fc10609cdc 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -276,6 +276,9 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): checkpoint_callback = self.lightning_module.trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + # requires to compute the state_dict on all processes in case Metrics are present + state_dict = self.lightning_module.state_dict() + if self.global_rank == 0 and self.mp_queue is not None: rank_zero_warn("cleaning up ddp environment...") @@ -286,7 +289,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): and len(best_model_path) > 0 ): last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - atomic_save(self.on_save(self.lightning_module.state_dict()), last_path) + atomic_save(self.on_save(state_dict), last_path) # todo, pass complete checkpoint as state dictionary self.mp_queue.put(best_model_path) diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 4e75358b67fae..b3a22ad1ad3b2 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -19,8 +19,8 @@ import torch from torch.utils.data import DataLoader +import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -37,7 +37,7 @@ class LightningIPUModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: LightningModule, precision: Union[str, int]): + def __init__(self, pl_module: 'pl.LightningModule', precision: Union[str, int]): super().__init__(pl_module) self.precision = precision @@ -184,7 +184,7 @@ def _validate_opts(self, opts: 'poptorch.Options', training: bool) -> None: opts.Training.set(gradient_accumulation=1) @property - def lightning_module(self) -> Optional[LightningModule]: + def lightning_module(self) -> Optional['pl.LightningModule']: return self.model.module if isinstance(self.model, LightningIPUModule) else self.model def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index fceafddd66ec0..7e5796d5b5668 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -16,7 +16,7 @@ import torch from torch.optim import Optimizer -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.core.optimizer import is_lightning_optimizer from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.states import TrainerFn @@ -86,7 +86,7 @@ def _optim_state_dict(self, optimizer): 
return optimizer.state_dict() @property - def lightning_module(self) -> LightningModule: + def lightning_module(self) -> 'pl.LightningModule': if not _FAIRSCALE_AVAILABLE: # pragma: no cover raise MisconfigurationException( "`DDPShardedPlugin` requires `fairscale` to be installed." diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 5daf4e5be3735..c583ac756cd0f 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -16,7 +16,7 @@ import torch from torch.optim import Optimizer -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerFn @@ -71,7 +71,7 @@ def _optim_state_dict(self, optimizer): return optimizer.state_dict() @property - def lightning_module(self) -> LightningModule: + def lightning_module(self) -> 'pl.LightningModule': if not _FAIRSCALE_AVAILABLE: # pragma: no cover raise MisconfigurationException( "`DDPSpawnShardedPlugin` requires `fairscale` to be installed." diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 1fd9bbcb0a2cf..68e189f6f60cd 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -185,6 +185,9 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): checkpoint_callback = self.lightning_module.trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + # requires to compute the state_dict on all processes in case Metrics are present + state_dict = self.lightning_module.state_dict() + if self.mp_queue is not None: rank_zero_warn("cleaning up tpu spawn environment...") @@ -195,7 +198,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): and len(best_model_path) > 0 ): last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.save(self.lightning_module.state_dict(), last_path) + self.save(state_dict, last_path) if self.local_rank == 0: # todo, pass complete checkpoint as state dictionary diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 1f17308df73b3..4f4e44e57d3a3 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -17,8 +17,8 @@ from inspect import signature from typing import Any, Callable, Dict, List, Optional, Type +import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT @@ -32,19 +32,19 @@ class TrainerCallbackHookMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class callbacks: List[Callback] = [] - lightning_module: LightningModule + lightning_module: 'pl.LightningModule' - def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: + def 
on_before_accelerator_backend_setup(self, model: 'pl.LightningModule') -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def configure_sharded_model(self, model: LightningModule) -> None: + def configure_sharded_model(self, model: 'pl.LightningModule') -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.on_configure_sharded_model(self, model) - def setup(self, model: LightningModule, stage: Optional[str]) -> None: + def setup(self, model: 'pl.LightningModule', stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.setup(self, model, stage=stage) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 5652a65ee6df0..75cd74b307852 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -15,9 +15,9 @@ from datetime import timedelta from typing import Dict, List, Optional, Union +import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase from pytorch_lightning.callbacks.timer import Timer -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -137,7 +137,7 @@ def attach_model_logging_functions(self, model): callback.log_dict = model.log_dict @staticmethod - def _attach_model_callbacks(model: LightningModule, trainer) -> None: + def _attach_model_callbacks(model: 'pl.LightningModule', trainer) -> None: """ Attaches the callbacks defined in the model. If a callback returned by the model's configure_callback method has the same type as one or several diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index c2a0411c0df36..f1620c10bbd45 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -19,8 +19,7 @@ import torch -import pytorch_lightning -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.utilities import ( _OMEGACONF_AVAILABLE, DeviceType, @@ -292,8 +291,8 @@ def hpc_save(self, folderpath: str, logger): try: atomic_save(checkpoint, filepath) except AttributeError as err: - if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: - del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] + if pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: + del checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] rank_zero_warn( 'warning, `hyper_parameters` dropped from checkpoint.' 
f' An attribute is not picklable {err}' @@ -339,7 +338,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: checkpoint = { 'epoch': current_epoch, 'global_step': global_step, - 'pytorch-lightning_version': pytorch_lightning.__version__, + 'pytorch-lightning_version': pl.__version__, 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), } @@ -366,13 +365,13 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # dump hyper-parameters if model.hparams: if hasattr(model, '_hparams_name'): - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name # dump arguments if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container): - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) else: - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) # give the model a chance to dump a few things model.on_save_checkpoint(checkpoint) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 5e08a82e4bf7e..cbed7368a4372 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -14,15 +14,15 @@ from collections.abc import Generator from dataclasses import asdict, dataclass, replace from functools import partial, wraps -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union import torch from torchmetrics import Metric -from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin +from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.metrics import metrics_to_scalars @@ -46,6 +46,7 @@ class MetricSource(LightningEnum): class _Sync: fn: Optional[Callable] = None should: bool = False + rank_zero_only: bool = False op: Optional[str] = None group: Optional[Any] = None @@ -55,7 +56,10 @@ def __post_init__(self) -> None: @property def __call__(self) -> Any: - return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op + return ( + partial(self.fn, reduce_op=self.op, group=self.group) + if self.should and not self.rank_zero_only else self.no_op + ) @staticmethod def no_op(value: Any, *_, **__) -> Any: @@ -73,6 +77,7 @@ class _Metadata: _reduce_fx: Callable = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None + metric_attribute: Optional[str] = None _sync: Optional[_Sync] = None @property @@ -165,9 +170,9 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: self.meta = metadata self.has_reset = False 
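The ``_Sync`` change above means that a value flagged ``rank_zero_only`` is never reduced across processes, even when ``sync_dist`` was requested, since only rank 0 would show up for the collective call. A simplified, standalone sketch of that selection logic (the names mirror the dataclass, but this is not the patched class itself).

.. code-block:: python

    from dataclasses import dataclass
    from functools import partial
    from typing import Any, Callable, Optional


    @dataclass
    class SyncChoice:
        fn: Optional[Callable] = None
        should: bool = False          # the user asked for sync_dist=True
        rank_zero_only: bool = False  # the value exists on rank 0 only; never sync it
        op: Optional[str] = None
        group: Optional[Any] = None

        @staticmethod
        def no_op(value: Any, *_: Any, **__: Any) -> Any:
            return value

        def resolve(self) -> Callable:
            # Reduce only when syncing was requested *and* the value is present
            # on every rank; otherwise hand the value back untouched.
            if self.should and not self.rank_zero_only and self.fn is not None:
                return partial(self.fn, reduce_op=self.op, group=self.group)
            return self.no_op


    # usage: a rank-zero-only value is returned as-is, with no collective call
    choice = SyncChoice(fn=None, should=True, rank_zero_only=True)
    assert choice.resolve()(42) == 42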
if is_tensor: - self.add_state("value", torch.tensor(0, dtype=torch.float)) + self.add_state("value", torch.tensor(0, dtype=torch.float), dist_reduce_fx=torch.sum) if self.meta.is_mean_reduction: - self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) + self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float), dist_reduce_fx=torch.sum) def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: if self.is_tensor: @@ -238,13 +243,22 @@ def __setattr__(self, key: str, value: Any) -> None: object.__setattr__(self, key, value) def __repr__(self) -> str: - state = f"value={self.value}" + state = f"{repr(self.meta.name)}, value={self.value}" if self.is_tensor and self.meta.is_mean_reduction: state += f", cumulated_batch_size={self.cumulated_batch_size}" return f"{self.__class__.__name__}({state})" - def __getstate__(self) -> dict: - d = super().__getstate__() + def __getstate__(self, drop_value: bool = False) -> dict: + skip = ['update', 'compute', '_update_signature'] + if not self.is_tensor and drop_value: + # Avoid serializing ResultMetrics which are passed Metrics + skip.append('value') + with self.sync_context( + should_sync=not self.meta.sync.rank_zero_only, + process_group=self.meta.sync.group, + distributed_available=distributed_available + ): + d = {k: v for k, v in self.__dict__.items() if k not in skip} d['meta'] = d['meta'].__getstate__() d['_class'] = self.__class__.__name__ return d @@ -275,10 +289,10 @@ def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None: super().__init__(*args) self.meta = metadata - def __getstate__(self) -> dict: + def __getstate__(self, drop_value: bool = False) -> dict: def getstate(item: ResultMetric) -> dict: - return item.__getstate__() + return item.__getstate__(drop_value=drop_value) items = apply_to_collection(dict(self), (ResultMetric, ResultMetricCollection), getstate) return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__} @@ -331,7 +345,17 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = self._minimize = None self._batch_size = torch.tensor(1, device=device) self.device: Optional[Union[str, torch.device]] = device - self.fx_validator = FxValidator() + + @property + def result_metrics(self) -> List[ResultMetric]: + o = [] + + def append_fn(v: ResultMetric) -> None: + nonlocal o + o.append(v) + + apply_to_collection(list(self.values()), ResultMetric, append_fn) + return o @property def batch_size(self) -> torch.Tensor: @@ -398,6 +422,8 @@ def log( sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, + metric_attribute: Optional[str] = None, + rank_zero_only: bool = False, ) -> None: """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs @@ -424,16 +450,21 @@ def log( on_epoch=on_epoch, enable_graph=enable_graph, dataloader_idx=dataloader_idx, + metric_attribute=metric_attribute, ) meta.reduce_fx = reduce_fx meta.sync = _Sync( should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, + rank_zero_only=rank_zero_only, ) + # register logged value if it doesn't exist if key not in self: self.register_key(key, meta, value) + + # check the stored metadata and the current one match elif meta != self[key].meta: raise MisconfigurationException( f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. 
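The ``dist_reduce_fx=torch.sum`` additions above are what make the mean reduction correct across processes: both the running total and the accumulated batch size are summed over ranks before the final division. The same idea as a self-contained ``torchmetrics`` metric (a sketch, not the ``ResultMetric`` class from the patch).

.. code-block:: python

    import torch
    from torchmetrics import Metric


    class DistributedMean(Metric):
        def __init__(self):
            super().__init__()
            # Both states are summed across processes on sync, so the final
            # division yields the global mean rather than a per-rank mean.
            self.add_state("value", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("cumulated_batch_size", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, value: torch.Tensor, batch_size: int) -> None:
            self.value += value.float() * batch_size
            self.cumulated_batch_size += batch_size

        def compute(self) -> torch.Tensor:
            return self.value / self.cumulated_batch_size


    metric = DistributedMean()
    metric.update(torch.tensor(2.0), batch_size=4)
    metric.update(torch.tensor(4.0), batch_size=4)
    assert metric.compute() == torch.tensor(3.0)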
This is not allowed' @@ -472,7 +503,11 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten cache = result_metric._forward_cache elif not on_step and result_metric.meta.on_epoch: if not result_metric._computed: + # always reduce on epoch end + should = result_metric.meta.sync.should + result_metric.meta.sync.should = True result_metric.compute() + result_metric.meta.sync.should = should cache = result_metric._computed if cache is not None and not result_metric.meta.enable_graph: return cache.detach() @@ -597,20 +632,28 @@ def cpu(self) -> 'ResultCollection': def __str__(self) -> str: return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' - def __getstate__(self) -> dict: + def __getstate__(self, drop_value: bool = True) -> dict: d = self.__dict__.copy() + # can't deepcopy tensors with grad_fn minimize = d['_minimize'] if minimize is not None: d['_minimize'] = minimize.detach() + extra = self.get('_extra') if extra is not None: d['_extra'] = extra + # all the items should be either `ResultMetric`s or `ResultMetricCollection`s - items = {k: v.__getstate__() for k, v in self.items() if k != '_extra'} + items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items() if k != '_extra'} return {**d, 'items': items} - def __setstate__(self, state: dict, map_location: Optional[Union[str, torch.device]] = None) -> None: + def __setstate__( + self, + state: dict, + map_location: Optional[Union[str, torch.device]] = None, + sync_fn: Optional[Callable] = None, + ) -> None: self.__dict__.update({k: v for k, v in state.items() if k != 'items'}) def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: @@ -623,8 +666,8 @@ def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: cls = ResultMetricCollection else: raise ValueError(f"Unexpected class name: {cls}") - sync_fn = self[k].meta.sync.fn if k in self else None - return cls._reconstruct(item, sync_fn=sync_fn) + _sync_fn = sync_fn or (self[k].meta.sync.fn if k in self else None) + return cls._reconstruct(item, sync_fn=_sync_fn) items = {k: setstate(k, v) for k, v in state['items'].items()} self.update(items) @@ -632,8 +675,22 @@ def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: device = map_location or self.device self.to(device) - def state_dict(self) -> dict: - return self.__getstate__() + def state_dict(self, drop_value: bool = True) -> dict: + return self.__getstate__(drop_value) - def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None: - self.__setstate__(state_dict, map_location=map_location) + def load_state_dict( + self, + state_dict: dict, + map_location: Optional[Union[str, torch.device]] = None, + sync_fn: Optional[Callable] = None, + metrics: Optional[Dict[str, Metric]] = None, + ) -> None: + self.__setstate__(state_dict, map_location=map_location, sync_fn=sync_fn) + + if not metrics: + return + result_metrics = self.result_metrics + for metric_attribute, metric in metrics.items(): + for result_metric in result_metrics: + if result_metric.meta.metric_attribute == metric_attribute: + result_metric.value = metric diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index c9b8a6f29652b..ce6caa4e2f330 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -22,8 +22,8 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, 
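Putting the new ``state_dict`` / ``load_state_dict`` hooks above together: serialized entries may drop the live ``torchmetrics.Metric`` objects, so on restore the caller hands them back keyed by their ``metric_attribute``. A usage sketch modeled on the integration test further down in this patch; the metric name and elided ``log`` calls are illustrative.

.. code-block:: python

    from torchmetrics import Accuracy
    from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection

    # a live metric owned by the LightningModule, e.g. `self.val_acc`
    val_acc = Accuracy()

    results = ResultCollection(training=True, device="cpu")
    # ... `results.log(...)` calls happen during the run, referencing `val_acc`
    # with `metric_attribute="val_acc"` ...

    checkpointed = results.state_dict()  # the live Metric objects are not serialized

    restored = ResultCollection(training=True, device="cpu")
    # Re-attach the live metrics by attribute name so the restored entries
    # point at real metric state again.
    restored.load_state_dict(checkpointed, metrics={"val_acc": val_acc})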
SequentialSampler from torch.utils.data.distributed import DistributedSampler +import pytorch_lightning as pl from pytorch_lightning.accelerators import Accelerator -from pytorch_lightning.core import LightningModule from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.states import RunningStage @@ -226,7 +226,7 @@ def _get_distributed_sampler( sampler = cls(dataloader.dataset, **kwargs) return sampler - def reset_train_dataloader(self, model: LightningModule) -> None: + def reset_train_dataloader(self, model: 'pl.LightningModule') -> None: """Resets the train dataloader and initialises required variables (number of batches, when to validate, etc.). @@ -312,7 +312,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: def _reset_eval_dataloader( self, - model: LightningModule, + model: 'pl.LightningModule', mode: str, ) -> Tuple[List[Union[int, float]], List[DataLoader]]: """Generic method to reset a dataloader for evaluation. @@ -412,7 +412,7 @@ def _reset_eval_dataloader( return loader_num_batches, dataloaders - def reset_val_dataloader(self, model: LightningModule) -> None: + def reset_val_dataloader(self, model: 'pl.LightningModule') -> None: """Resets the validation dataloader and determines the number of batches. Args: @@ -457,7 +457,7 @@ def reset_train_val_dataloaders(self, model) -> None: if self.val_dataloaders is None: self.reset_val_dataloader(model) - def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader: + def request_dataloader(self, model: 'pl.LightningModule', stage: str) -> DataLoader: """Handles downloading data in the GPU or TPU case. Args: diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index cbf331913e597..2336379fc3d49 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -15,7 +15,7 @@ from abc import ABC from typing import Optional -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -27,9 +27,9 @@ class TrainerModelHooksMixin(ABC): Use the utilities from ``pytorch_lightning.utilities.signature_utils`` instead. """ - lightning_module: LightningModule + lightning_module: 'pl.LightningModule' - def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool: + def is_function_implemented(self, f_name: str, model: Optional['pl.LightningModule'] = None) -> bool: rank_zero_deprecation( "Internal: TrainerModelHooksMixin.is_function_implemented is deprecated in v1.4" " and will be removed in v1.6." 
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index b5afe7bf75168..80ec5857de287 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -19,7 +19,7 @@ from torch import optim from torch.optim.optimizer import Optimizer -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -29,7 +29,7 @@ class TrainerOptimizersMixin(ABC): _lightning_optimizers: Optional[List[LightningOptimizer]] - def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: + def init_optimizers(self, model: 'pl.LightningModule') -> Tuple[List, List, List]: self._lightning_optimizers = None optim_conf = model.configure_optimizers() if optim_conf is None: diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index b77b1b8268b9a..d9620112479f2 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -21,11 +21,11 @@ import torch from torch.optim import Optimizer +import pytorch_lightning as pl from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger @@ -146,7 +146,7 @@ def data_parallel_device_ids(self) -> Optional[List[int]]: return self.accelerator_connector.parallel_device_ids @property - def lightning_module(self) -> LightningModule: + def lightning_module(self) -> 'pl.LightningModule': return self.accelerator.lightning_module @property @@ -277,7 +277,7 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]: def progress_bar_dict(self) -> dict: """ Read-only for progress bar metrics. 
""" ref_model = self.lightning_module - ref_model = cast(LightningModule, ref_model) + ref_model = cast(pl.LightningModule, ref_model) standard_metrics = ref_model.get_progress_bar_dict() pbar_metrics = self.progress_bar_metrics diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c5ee90cd126ce..4d097a2de2763 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -21,10 +21,10 @@ import torch +import pytorch_lightning as pl from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop @@ -470,7 +470,7 @@ def _setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamod def fit( self, - model: LightningModule, + model: 'pl.LightningModule', train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, @@ -526,7 +526,7 @@ def fit( def validate( self, - model: Optional[LightningModule] = None, + model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, @@ -602,7 +602,7 @@ def validate( def test( self, - model: Optional[LightningModule] = None, + model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, @@ -677,7 +677,7 @@ def test( def predict( self, - model: Optional[LightningModule] = None, + model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, @@ -747,7 +747,7 @@ def predict( def tune( self, - model: LightningModule, + model: 'pl.LightningModule', train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, @@ -807,7 +807,7 @@ def tune( return result - def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) @@ -1090,7 +1090,7 @@ def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: self.checkpoint_connector.restore_model_weights(ckpt_path) return ckpt_path - def _call_setup_hook(self, model: LightningModule) -> None: + def _call_setup_hook(self, model: 'pl.LightningModule') -> None: fn = self.state.fn._setup_fn self.accelerator.barrier("pre_setup") @@ -1102,7 +1102,7 @@ def _call_setup_hook(self, model: LightningModule) -> None: self.accelerator.barrier("post_setup") - def _call_configure_sharded_model(self, model: LightningModule) -> None: + def _call_configure_sharded_model(self, model: 'pl.LightningModule') -> None: # Call configure sharded model hook if accelerator requests. 
In some cases # we will not call the hook; the hook has initialized the sharded model for example. @@ -1115,7 +1115,7 @@ def _call_configure_sharded_model(self, model: LightningModule) -> None: model.call_configure_sharded_model_hook = True self.accelerator.call_configure_sharded_model_hook = False - def _call_teardown_hook(self, model: LightningModule) -> None: + def _call_teardown_hook(self, model: 'pl.LightningModule') -> None: fn = self.state.fn._setup_fn if self.datamodule is not None: @@ -1126,6 +1126,8 @@ def _call_teardown_hook(self, model: LightningModule) -> None: model._current_fx_name = None model._current_dataloader_idx = None + # these could have become stale if metrics are defined in `setup` + model._metric_attributes = None def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # Note this implementation is copy/pasted into the TrainLoop class in TrainingEpochLoop._on_train_epoch_end_hook diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index a45c9436dbdb7..beecc5e2a764d 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -18,7 +18,7 @@ import torch from torch import Tensor -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients @@ -34,7 +34,7 @@ class TrainerTrainingTricksMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - lightning_module: LightningModule + lightning_module: 'pl.LightningModule' def print_nan_gradients(self) -> None: rank_zero_deprecation( diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index ae977bd03bac8..5094f55ba59f8 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -135,6 +135,10 @@ def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None) return gathered_result +def distributed_available() -> bool: + return torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() + + def sync_ddp_if_available( result: Union[torch.Tensor], group: Optional[Any] = None, @@ -151,7 +155,7 @@ def sync_ddp_if_available( Return: reduced value """ - if torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed(): + if distributed_available(): return sync_ddp(result, group=group, reduce_op=reduce_op) return result @@ -230,7 +234,7 @@ def all_gather_ddp_if_available( A tensor of shape (world_size, batch, ...) 
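``distributed_available()`` above simply bundles the ``torch.distributed`` initialization check with the TPU case so callers stop repeating the compound condition. A short sketch of using it to guard an optional reduction; the helper names come from the patch, while the wrapping function is illustrative.

.. code-block:: python

    import torch
    from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp


    def maybe_reduce(value: torch.Tensor) -> torch.Tensor:
        # Mirrors how `self.log` now decides whether a sync function may run at all:
        # only attempt the collective op when a distributed backend is actually up.
        if distributed_available():
            return sync_ddp(value, reduce_op="mean")
        return value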
""" group = group if group is not None else torch.distributed.group.WORLD - if torch.distributed.is_available() and torch.distributed.is_initialized(): + if distributed_available(): if sync_grads: return AllGatherGrad.apply(tensor, group) else: diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index b7c3c09aff60b..e52f8efa2689f 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -15,8 +15,7 @@ from typing import Optional, Type, Union from unittest.mock import Mock -from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation @@ -24,7 +23,7 @@ def is_overridden( method_name: str, instance: Optional[object] = None, parent: Optional[Type[object]] = None, - model: Optional[Union[LightningModule, LightningDataModule]] = None, + model: Optional[Union['pl.LightningModule', 'pl.LightningDataModule']] = None, ) -> bool: if model is not None and instance is None: rank_zero_deprecation( @@ -38,10 +37,10 @@ def is_overridden( return False if parent is None: - if isinstance(instance, LightningModule): - parent = LightningModule - elif isinstance(instance, LightningDataModule): - parent = LightningDataModule + if isinstance(instance, pl.LightningModule): + parent = pl.LightningModule + elif isinstance(instance, pl.LightningDataModule): + parent = pl.LightningDataModule if parent is None: raise ValueError("Expected a parent") diff --git a/requirements.txt b/requirements.txt index 001bb24219597..e6b373036675e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ tqdm>=4.41.0 PyYAML>=5.1,<=5.4.1 fsspec[http]>=2021.05.0, !=2021.06.0 tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!' 
-torchmetrics>=0.3.2 +torchmetrics>=0.4.0rc1 pyDeprecate==0.3.1 packaging>=17.0 typing-extensions # TypedDict support for python<3.8 diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 9fdd69dba7a9a..0073676a77eec 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -120,7 +120,7 @@ def training_step(self, batch, batch_idx): def training_epoch_end(self, outputs) -> None: local_rank = int(os.getenv("LOCAL_RANK")) if self.trainer.is_global_zero: - self.log('my_loss_2', (1 + local_rank), on_epoch=True) + self.log('my_loss_2', (1 + local_rank), on_epoch=True, rank_zero_only=True) data = str(self.global_rank) obj = [[data], (data, ), set(data)] out = self.trainer.training_type_plugin.broadcast(obj) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 6b7163c4aa643..7471914886a27 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -135,9 +135,9 @@ def test_result_metric_integration(): assert str(result) == ( "ResultCollection(True, cpu, {" - "'h.a': ResultMetric(value=DummyMetric()), " - "'h.b': ResultMetric(value=DummyMetric()), " - "'h.c': ResultMetric(value=DummyMetric())" + "'h.a': ResultMetric('a', value=DummyMetric()), " + "'h.b': ResultMetric('b', value=DummyMetric()), " + "'h.c': ResultMetric('c', value=DummyMetric())" "})" ) @@ -184,7 +184,7 @@ def lightning_log(fx, *args, **kwargs): assert result[k].cumulated_batch_size == torch.tensor(1.), k -def my_sync_dist(x): +def my_sync_dist(x, *_, **__): return x @@ -208,7 +208,7 @@ def lightning_log(fx, *args, **kwargs): result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist) current_fx_name = fx - for _ in range(2): + for epoch in range(2): cumulative_sum = 0 @@ -222,9 +222,9 @@ def lightning_log(fx, *args, **kwargs): cumulative_sum += i metric = metric_a if i < 1 else metric_d - lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True) - lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True) - lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False) + lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True, metric_attribute="metric") + lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") + lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True) lightning_log('training_step', 'b_1', b, on_step=False, on_epoch=True) lightning_log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False) @@ -238,7 +238,17 @@ def lightning_log(fx, *args, **kwargs): state_dict = result.state_dict() # check the sync fn was dropped assert 'fn' not in state_dict['items']['training_step.a']['meta']['_sync'] - new_result.load_state_dict(state_dict) + + assert not new_result.result_metrics + assert len(result.result_metrics) == 7 + epoch > 0 + + new_result.load_state_dict( + state_dict, metrics={ + "metric": metric, + "metric_b": metric_b, + "metric_c": metric_c + } + ) # should match assert result_copy == new_result # the sync fn has been kept @@ -290,7 +300,8 @@ def validation_step(self, batch, batch_idx): def on_save_checkpoint(self, checkpoint) -> None: results = self.trainer._results - state_dict = results.state_dict() + # 
simplify logic + state_dict = results.state_dict(drop_value=False) # check device assert results['validation_step.v'].value.device.type == device diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index d3703bf3691c9..aa7d4977d1133 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -215,7 +215,7 @@ def test_v1_5_metric_classif_mix(): preds = torch.tensor([0, 1, 0, 0]) confusion_matrix._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert torch.equal(confusion_matrix(preds, target, num_classes=2), torch.tensor([[2., 0.], [1., 1.]])) + assert torch.equal(confusion_matrix(preds, target, num_classes=2).float(), torch.tensor([[2., 0.], [1., 1.]])) target = torch.tensor([0, 1, 2, 0, 1, 2]) preds = torch.tensor([0, 2, 1, 0, 0, 1]) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 142505ecde890..9a689fe9d725a 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -530,7 +530,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): 'lr_schedulers': ANY, 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, - 'state_dict': ANY + 'state_dict': ANY, }, ) ), dict(name='configure_sharded_model'), @@ -799,7 +799,7 @@ def call(hook, fn, *args, **kwargs): 'lr_schedulers': ANY, 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, - 'state_dict': ANY + 'state_dict': ANY, }, ) ), dict(name='teardown', kwargs=dict(stage='fit')), diff --git a/tests/special_tests.sh b/tests/special_tests.sh index a87f50548d06b..f76ac419ea866 100755 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -17,7 +17,7 @@ set -e # this environment variable allows special tests to run export PL_RUNNING_SPECIAL_TESTS=1 # python arguments -defaults='-m coverage run --source pytorch_lightning --append -m pytest --verbose --capture=no --disable-warnings' +defaults='-m coverage run --source pytorch_lightning --append -m pytest --durations=0 --capture=no --disable-warnings' # find tests marked as `@RunIf(special=True)` grep_output=$(grep --recursive --line-number --word-regexp 'tests' 'benchmarks' --regexp 'special=True') diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 5716b8b5f07f0..592fde1569344 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -23,7 +23,6 @@ from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection -from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -293,48 +292,63 @@ def __init__(self): super().__init__() self.layer = torch.nn.Linear(32, 1) - for stage in ['train', 'val', 'test']: - acc = Accuracy() - acc.reset = mock.Mock(side_effect=acc.reset) - ap = AveragePrecision(num_classes=1, pos_label=1) - ap.reset = mock.Mock(side_effect=ap.reset) - self.add_module(f"acc_{stage}", acc) - self.add_module(f"ap_{stage}", ap) + def _create_metrics(self): + acc = Accuracy() + acc.reset = mock.Mock(side_effect=acc.reset) + ap = AveragePrecision(num_classes=1, pos_label=1) + ap.reset = 
mock.Mock(side_effect=ap.reset) + return acc, ap + + def setup(self, stage): + fn = stage + if fn == 'fit': + for stage in ('train', 'validate'): + acc, ap = self._create_metrics() + self.add_module(f"acc_{fn}_{stage}", acc) + self.add_module(f"ap_{fn}_{stage}", ap) + else: + acc, ap = self._create_metrics() + stage = self.trainer.state.stage + self.add_module(f"acc_{fn}_{stage}", acc) + self.add_module(f"ap_{fn}_{stage}", ap) def forward(self, x): return self.layer(x) - def _step(self, stage, batch): - labels = (batch.detach().sum(1) > 0).float() # Fake some targets - logits = self.forward(batch) - loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1)) - probs = torch.sigmoid(logits.detach()) - self.log(f"loss/{stage}", loss) + def _step(self, batch): + fn, stage = self.trainer.state.fn, self.trainer.state.stage + + logits = self(batch) + loss = logits.sum() + self.log(f"loss/{fn}_{stage}", loss) - acc = self._modules[f"acc_{stage}"] - ap = self._modules[f"ap_{stage}"] + acc = self._modules[f"acc_{fn}_{stage}"] + ap = self._modules[f"ap_{fn}_{stage}"] - labels_int = labels.to(torch.long) - acc(probs.flatten(), labels_int) - ap(probs.flatten(), labels_int) + preds = torch.rand(len(batch)) # Fake preds + labels = torch.randint(0, 1, [len(batch)]) # Fake targets + acc(preds, labels) + ap(preds, labels) # Metric.forward calls reset so reset the mocks here acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc) - self.log(f"{stage}/ap", ap) + self.log(f"acc/{fn}_{stage}", acc) + self.log(f"ap/{fn}_{stage}", ap) return loss def training_step(self, batch, batch_idx, *args, **kwargs): - return self._step('train', batch) + return self._step(batch) def validation_step(self, batch, batch_idx, *args, **kwargs): - return self._step('val', batch) + if self.trainer.sanity_checking: + return + return self._step(batch) def test_step(self, batch, batch_idx, *args, **kwargs): - return self._step('test', batch) + return self._step(batch) def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -350,33 +364,11 @@ def val_dataloader(self): def test_dataloader(self): return DataLoader(RandomDataset(32, 64)) - def _assert_epoch_end(self, stage): - acc = self._modules[f"acc_{stage}"] - ap = self._modules[f"ap_{stage}"] - - acc.reset.assert_called_once() - ap.reset.assert_called_once() - - def teardown(self, stage): - if stage == TrainerFn.FITTING: - self._assert_epoch_end('train') - self._assert_epoch_end('val') - - elif stage == TrainerFn.VALIDATING: - self._assert_epoch_end('val') - - elif stage == TrainerFn.TESTING: - self._assert_epoch_end('test') - - def _assert_called(model, stage): - acc = model._modules[f"acc_{stage}"] - ap = model._modules[f"ap_{stage}"] - - assert acc.reset.call_count == 1 - acc.reset.reset_mock() - - assert ap.reset.call_count == 1 - ap.reset.reset_mock() + def _assert_called(model, fn, stage): + acc = model._modules[f"acc_{fn}_{stage}"] + ap = model._modules[f"ap_{fn}_{stage}"] + acc.reset.assert_called_once() + ap.reset.assert_called_once() model = TestModel() trainer = Trainer( @@ -387,17 +379,18 @@ def _assert_called(model, stage): max_epochs=1, progress_bar_refresh_rate=0, num_sanity_val_steps=2, + checkpoint_callback=False, ) trainer.fit(model) - _assert_called(model, 'train') - _assert_called(model, 'val') + _assert_called(model, 'fit', 'train') + _assert_called(model, 'fit', 'validate') trainer.validate(model) - _assert_called(model, 'val') + _assert_called(model, 'validate', 
'validate') trainer.test(model) - _assert_called(model, 'test') + _assert_called(model, 'test', 'test') def test_result_collection_on_tensor_with_mean_reduction():
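The updated logger-connector test above builds its metrics inside ``setup`` so that each ``trainer.fit`` / ``validate`` / ``test`` entry point gets fresh metric state, which is also why the teardown hook now clears the cached ``_metric_attributes``. A condensed sketch of that pattern outside the test harness; the module and metric names are illustrative.

.. code-block:: python

    import torch
    import torchmetrics
    import pytorch_lightning as pl


    class StageMetricsModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def setup(self, stage):
            # Created per entry point, so a later `trainer.test(...)` does not
            # silently reuse stale state from `trainer.fit(...)`.
            if stage == "fit":
                self.train_acc = torchmetrics.Accuracy()
            elif stage == "test":
                self.test_acc = torchmetrics.Accuracy()

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self.layer(x)
            loss = torch.nn.functional.cross_entropy(logits, y)
            self.train_acc(logits.softmax(dim=-1), y)
            # Metrics created in `setup` are still direct attributes, so they are
            # discovered without an explicit `metric_attribute`.
            self.log("train_acc", self.train_acc, on_epoch=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)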