Merged

Changes from all commits (108 commits)
4cb7e89
add metric reload
tchaton Jun 11, 2021
4176447
add tests
tchaton Jun 11, 2021
9594653
update changelog
tchaton Jun 11, 2021
0fa64ed
udpate
tchaton Jun 11, 2021
9828e72
remove print
tchaton Jun 11, 2021
f85d590
remove attribute_name
tchaton Jun 14, 2021
31d390d
update
tchaton Jun 14, 2021
e7644de
update
tchaton Jun 14, 2021
3a1019e
updat
tchaton Jun 14, 2021
bd4b23b
update
tchaton Jun 14, 2021
659a25a
resolve test
tchaton Jun 14, 2021
8ab34ce
updat
tchaton Jun 14, 2021
dafba9b
update
tchaton Jun 14, 2021
357147b
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 14, 2021
b774c34
update
tchaton Jun 14, 2021
c453994
update on comments
tchaton Jun 14, 2021
b811478
Merge branch 'fault_tolerant_log' into fault_tolerant_log_2/n
tchaton Jun 14, 2021
9f46a99
add test
tchaton Jun 14, 2021
d80eb00
remove tmp.p
tchaton Jun 14, 2021
c102176
bypass typing bug
tchaton Jun 14, 2021
533f658
Merge branch 'fault_tolerant_log' into fault_tolerant_log_2/n
tchaton Jun 14, 2021
cc23140
add deepcopy to keep sync_fn target
tchaton Jun 14, 2021
d152985
add changelog
tchaton Jun 14, 2021
b909012
remove test changes
tchaton Jun 14, 2021
2c13e5e
typo
tchaton Jun 14, 2021
23854b3
Merge branch 'master' into fault_tolerant_log_2/n
tchaton Jun 17, 2021
1580696
add result collection
tchaton Jun 17, 2021
5331b8e
Update pytorch_lightning/trainer/connectors/logger_connector/result.py
tchaton Jun 17, 2021
e966cfd
update
tchaton Jun 17, 2021
0cd7ba1
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 17, 2021
40ef8d8
add support for metric and reduction
tchaton Jun 17, 2021
b90916e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2021
5f3e4b3
wip
tchaton Jun 17, 2021
f843d29
update
tchaton Jun 17, 2021
f48cca5
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 17, 2021
f6be5d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2021
a117e41
update on comments
tchaton Jun 17, 2021
ffb1232
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 17, 2021
c455c69
improve test
tchaton Jun 18, 2021
6da2da3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
37e5310
resolve test
tchaton Jun 18, 2021
7aecafe
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 18, 2021
49b2647
test with torchmetrics
tchaton Jun 18, 2021
e22578d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
c312828
update on comments
tchaton Jun 18, 2021
7c2d63c
update torchmetrics path
tchaton Jun 18, 2021
fd30a99
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 18, 2021
15746fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
4c86636
update
tchaton Jun 18, 2021
27c0de7
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 18, 2021
e560406
update
tchaton Jun 18, 2021
d0b012f
update setup
tchaton Jun 18, 2021
09f05f6
add directly in CI
tchaton Jun 18, 2021
d0a6cf9
update
tchaton Jun 18, 2021
a5ae2c4
Whitespace
carmocca Jun 18, 2021
f7eafd7
resolve bug
tchaton Jun 18, 2021
aa69e61
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 18, 2021
5f3b33c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
ecbd5e6
update
tchaton Jun 21, 2021
f3a3996
update on comments
tchaton Jun 21, 2021
2289712
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2021
0415498
update torchmetrics
tchaton Jun 21, 2021
bffc78e
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 21, 2021
6865a36
resolve tests
tchaton Jun 21, 2021
fc92382
Merge branch 'master' into fault_tolerant_log_2/n
tchaton Jun 22, 2021
813bcea
Merge branch 'master' into fault_tolerant_log_2/n
tchaton Jun 23, 2021
fd6cf34
get duration
tchaton Jun 23, 2021
fc66c10
resolve issues
tchaton Jun 23, 2021
979bc23
resolve bug
tchaton Jun 23, 2021
effff31
update
tchaton Jun 23, 2021
b38efe1
resolve tests
tchaton Jun 23, 2021
2b70bfb
update names
tchaton Jun 23, 2021
61d46bb
resolve bug
tchaton Jun 23, 2021
67ce691
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 23, 2021
87a9d67
doc update
tchaton Jun 23, 2021
db73f1d
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 23, 2021
0533200
update flake8
tchaton Jun 23, 2021
6f2b046
remove pdb
tchaton Jun 23, 2021
d831d4b
update on comments
tchaton Jun 23, 2021
eec17fb
update
tchaton Jun 23, 2021
d5db9d5
resolve test
tchaton Jun 23, 2021
1ed5eb8
format
Borda Jun 23, 2021
f7f1992
resolve tests
tchaton Jun 23, 2021
aad3fd4
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 23, 2021
a722c61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 23, 2021
45aa793
update
tchaton Jun 23, 2021
cad1559
Merge branch 'fault_tolerant_log_2/n' of https://github.com/PyTorchLi…
tchaton Jun 23, 2021
ac8df1c
update
tchaton Jun 23, 2021
4c9c0c1
update
tchaton Jun 23, 2021
66ad312
update
tchaton Jun 23, 2021
51b73c5
Merge branch 'master' into fault_tolerant_log_2/n
carmocca Jun 24, 2021
ced63b3
resolve conflicts
tchaton Jun 25, 2021
6c50332
Update CHANGELOG
carmocca Jun 25, 2021
766ef71
Docs
carmocca Jun 25, 2021
d898016
Rename metric prefix name
carmocca Jun 25, 2021
209298c
Refactor metric reset test
carmocca Jun 25, 2021
08f3c79
Typos
carmocca Jun 25, 2021
358cbd3
No need for should sync property
carmocca Jun 25, 2021
14d6f41
Decouple distributeda available
carmocca Jun 25, 2021
9471db6
Avoid deepcopy and dropping value
carmocca Jun 25, 2021
032c1f8
Remove fx validator in getstate
carmocca Jun 25, 2021
ee379cf
fx_validator shouldn't be in self.items()
carmocca Jun 25, 2021
7108190
Add reduce comment
carmocca Jun 25, 2021
eebd4d4
Improve result metrics property
carmocca Jun 25, 2021
e7bef8c
State dict wouldnt save metric attributes
carmocca Jun 25, 2021
4e49c98
Resolve circular imports to import the fx_validator
carmocca Jun 25, 2021
2c08db6
Minor changes
carmocca Jun 25, 2021
189f0ad
Revert checkpoint changes
carmocca Jun 25, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -85,6 +85,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fault-tolerant training
* Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
* Checkpoint the loop results ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))


- Add `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))


- Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))
17 changes: 17 additions & 0 deletions docs/source/advanced/multi_gpu.rst
@@ -106,6 +106,23 @@ Note if you use any built in metrics or custom metrics that use the :doc:`Metric
# Add sync_dist=True to sync logging across all GPU workers
self.log('test_loss', loss, on_step=True, on_epoch=True, sync_dist=True)

It is possible to perform some computation manually and log the reduced result on rank 0 as follows:

.. testcode::

def test_step(self, batch, batch_idx):
x, y = batch
tensors = self(x)
return tensors

def test_epoch_end(self, outputs):
mean = torch.mean(self.all_gather(outputs))

# When logging only on rank 0, don't forget to add
# ``rank_zero_only=True`` to avoid deadlocks on synchronization.
if self.trainer.is_global_zero:
self.log("my_reduced_metric", mean, rank_zero_only=True)


Make models pickleable
^^^^^^^^^^^^^^^^^^^^^^
39 changes: 35 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -39,11 +39,12 @@
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
@@ -112,6 +113,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._automatic_optimization: bool = True
self._truncated_bptt_steps: int = 0
self._param_requires_grad_state = dict()
self._metric_attributes: Optional[Dict[int, str]] = None

def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
if use_pl_optimizer:
@@ -273,6 +275,8 @@ def log(
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
rank_zero_only: Optional[bool] = None,
) -> None:
"""
Log a key, value
@@ -310,6 +314,10 @@ def log(
each dataloader to not mix values
batch_size: Current batch_size. This will be directly inferred from the loaded batch,
but some data structures might need to explicitly provide it.
metric_attribute: To restore the metric state, Lightning requires the reference of the
:class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute.
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
"""
if tbptt_reduce_fx is not None:
rank_zero_deprecation(
@@ -346,7 +354,7 @@ def log(
results = self.trainer._results
assert results is not None
assert self._current_fx_name is not None
results.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)
FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
@@ -362,6 +370,27 @@
# reset any tensors for the new hook name
results.reset(metrics=False, fx=self._current_fx_name)

if metric_attribute is None and isinstance(value, Metric):
if self._metric_attributes is None:
# compute once
self._metric_attributes = {
id(module): name
for name, module in self.named_children() if isinstance(module, Metric)
}
if not self._metric_attributes:
raise MisconfigurationException(
"Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
" You can fix this by setting an attribute for the metric in your `LightningModule`."
)
# try to find the passed metric in the LightningModule
metric_attribute = self._metric_attributes.get(id(value))
if metric_attribute is None:
raise MisconfigurationException(
"Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one"
f" of {list(self._metric_attributes.values())}"
)

results.log(
self._current_fx_name,
name,
@@ -374,9 +403,11 @@
enable_graph=enable_graph,
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
batch_size=batch_size,
sync_dist=sync_dist,
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available,
sync_dist=sync_dist and distributed_available(),
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
sync_dist_group=sync_dist_group,
metric_attribute=metric_attribute,
rank_zero_only=rank_zero_only,
)

self.trainer.logger_connector._current_fx = self._current_fx_name
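
To illustrate the new `metric_attribute` argument documented above (the `rank_zero_only` usage is shown in the docs diff earlier), a minimal sketch; the module, metric, and attribute names are illustrative, and the default `torchmetrics.Accuracy()` signature of the version pinned at the time is assumed:

import torch
import torch.nn.functional as F
import torchmetrics
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 10)
        # registering the metric as a child module lets Lightning resolve
        # `metric_attribute` automatically via `named_children()`
        self.train_acc = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.layer(x)
        loss = F.cross_entropy(logits, y)
        self.train_acc(logits.softmax(dim=-1), y)
        # `metric_attribute` is optional here because the metric is a module
        # attribute; it must be passed explicitly when Lightning cannot find
        # the metric, e.g. when it is stored inside a plain dict
        self.log("train_acc", self.train_acc, metric_attribute="train_acc")
        return loss

If the metric cannot be found, the `MisconfigurationException` raised in the diff above lists the valid attribute names to pass as `metric_attribute`.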
6 changes: 3 additions & 3 deletions pytorch_lightning/loggers/base.py
@@ -25,8 +25,8 @@
import numpy as np
import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only


@@ -300,7 +300,7 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
kwargs: Optional keywoard arguments, depends on the specific logger being used
"""

def log_graph(self, model: LightningModule, input_array=None) -> None:
def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None:
"""
Record model graph

@@ -396,7 +396,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
for logger in self._logger_iterable:
logger.log_hyperparams(params)

def log_graph(self, model: LightningModule, input_array=None) -> None:
def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None:
for logger in self._logger_iterable:
logger.log_graph(model, input_array)

4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/comet.py
@@ -24,7 +24,7 @@
import torch
from torch import is_tensor

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -318,6 +318,6 @@ def __getstate__(self):
state["_experiment"] = None
return state

def log_graph(self, model: LightningModule, input_array=None) -> None:
def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None:
if self._experiment is not None:
self._experiment.set_model_graph(model)
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/tensorboard.py
@@ -25,7 +25,7 @@
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.summary import hparams

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn
@@ -223,7 +223,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
raise ValueError(m) from ex

@rank_zero_only
def log_graph(self, model: LightningModule, input_array=None):
def log_graph(self, model: 'pl.LightningModule', input_array=None):
if self._log_graph:
if input_array is None:
input_array = model.example_input_array
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/test_tube.py
@@ -18,7 +18,7 @@
from argparse import Namespace
from typing import Any, Dict, Optional, Union

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
@@ -153,7 +153,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
self.experiment.log(metrics, global_step=step)

@rank_zero_only
def log_graph(self, model: LightningModule, input_array=None):
def log_graph(self, model: 'pl.LightningModule', input_array=None):
if self._log_graph:
if input_array is None:
input_array = model.example_input_array
4 changes: 2 additions & 2 deletions pytorch_lightning/overrides/fairscale.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE

@@ -23,7 +23,7 @@ class LightningShardedDataParallel(_LightningModuleWrapperBase):
# Just do this for later docstrings
pass

def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule:
def unwrap_lightning_module_sharded(wrapped_model) -> 'pl.LightningModule':
model = wrapped_model
if isinstance(model, ShardedDataParallel):
model = model.module
7 changes: 3 additions & 4 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -18,7 +18,6 @@
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities.types import _PARAMETERS
@@ -50,7 +49,7 @@ def dispatch(self, trainer: 'pl.Trainer') -> None:

def backward(
self,
model: LightningModule,
model: 'pl.LightningModule',
closure_loss: Tensor,
optimizer: Optimizer,
opt_idx: int,
@@ -76,7 +75,7 @@ def backward(

# do backward pass
# TODO: not entirely sure, why we need this
if model is not None and isinstance(model, LightningModule):
if model is not None and isinstance(model, pl.LightningModule):
model.backward(closure_loss, optimizer, opt_idx, **kwargs)

# TODO: avoid dev_debugger and track these calls with mock
@@ -118,7 +117,7 @@ def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Seq

def pre_optimizer_step(
self,
pl_module: LightningModule,
pl_module: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/precision/double.py
@@ -18,7 +18,7 @@
import torch.nn as nn
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -33,7 +33,7 @@ class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase):
pl_module: the model to wrap
"""

def __init__(self, pl_module: LightningModule):
def __init__(self, pl_module: 'pl.LightningModule'):
super().__init__(pl_module)

@staticmethod
@@ -96,7 +96,7 @@ def connect(
incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or
`lr_schedulers`.
"""
model = cast(LightningModule, model.to(dtype=torch.float64))
model = cast(pl.LightningModule, model.to(dtype=torch.float64))
model = LightningDoublePrecisionModule(model)

return super().connect(model, optimizers, lr_schedulers)
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -276,6 +276,9 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

# requires to compute the state_dict on all processes in case Metrics are present
state_dict = self.lightning_module.state_dict()

if self.global_rank == 0 and self.mp_queue is not None:
rank_zero_warn("cleaning up ddp environment...")

@@ -286,7 +289,7 @@
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)
atomic_save(self.on_save(state_dict), last_path)

# todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path)
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -19,8 +19,8 @@
import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -37,7 +37,7 @@

class LightningIPUModule(_LightningModuleWrapperBase):

def __init__(self, pl_module: LightningModule, precision: Union[str, int]):
def __init__(self, pl_module: 'pl.LightningModule', precision: Union[str, int]):
super().__init__(pl_module)
self.precision = precision

@@ -184,7 +184,7 @@ def _validate_opts(self, opts: 'poptorch.Options', training: bool) -> None:
opts.Training.set(gradient_accumulation=1)

@property
def lightning_module(self) -> Optional[LightningModule]:
def lightning_module(self) -> Optional['pl.LightningModule']:
return self.model.module if isinstance(self.model, LightningIPUModule) else self.model

def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/sharded.py
@@ -16,7 +16,7 @@
import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.states import TrainerFn
@@ -86,7 +86,7 @@ def _optim_state_dict(self, optimizer):
return optimizer.state_dict()

@property
def lightning_module(self) -> LightningModule:
def lightning_module(self) -> 'pl.LightningModule':
if not _FAIRSCALE_AVAILABLE: # pragma: no cover
raise MisconfigurationException(
"`DDPShardedPlugin` requires `fairscale` to be installed."
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -16,7 +16,7 @@
import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
@@ -71,7 +71,7 @@ def _optim_state_dict(self, optimizer):
return optimizer.state_dict()

@property
def lightning_module(self) -> LightningModule:
def lightning_module(self) -> 'pl.LightningModule':
if not _FAIRSCALE_AVAILABLE: # pragma: no cover
raise MisconfigurationException(
"`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -185,6 +185,9 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

# requires to compute the state_dict on all processes in case Metrics are present
state_dict = self.lightning_module.state_dict()

if self.mp_queue is not None:
rank_zero_warn("cleaning up tpu spawn environment...")

@@ -195,7 +198,7 @@
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
self.save(self.lightning_module.state_dict(), last_path)
self.save(state_dict, last_path)

if self.local_rank == 0:
# todo, pass complete checkpoint as state dictionary